Compare commits

...

219 Commits

Author SHA1 Message Date
Luis Pater
938af75954 Merge branch 'router-for-me:main' into main 2026-04-09 21:14:30 +08:00
Luis Pater
1dba2d0f81 fix(handlers): add base URL validation and improve API key deletion tests 2026-04-09 20:51:54 +08:00
Luis Pater
730809d8ea fix(auth): preserve and restore ready view cursors during index rebuilds 2026-04-09 20:26:16 +08:00
Luis Pater
5e81b65f2f fix(auth, executor): normalize Qwen base URL, adjust RefreshLead duration, and add tests 2026-04-09 18:07:07 +08:00
Supra4E8C
c42480a574 Merge pull request #501 from Ve-ria/main
feat: add glm-5.1 to CodeBuddy model list
2026-04-09 14:37:28 +08:00
rensumo
55c146a0e7 feat: add glm-5.1 to CodeBuddy model list 2026-04-09 14:20:26 +08:00
Luis Pater
ad8e3964ff fix(auth): add retry logic for 429 status with Retry-After and improve testing 2026-04-09 07:07:19 +08:00
Luis Pater
e9dc576409 Merge branch 'router-for-me:main' into main 2026-04-09 03:49:09 +08:00
Luis Pater
941334da79 fix(auth): handle OAuth model alias in retry logic and refine Qwen quota handling 2026-04-09 03:44:19 +08:00
Luis Pater
d54f816363 fix(executor): update Qwen user agent and enhance header configuration 2026-04-09 01:45:52 +08:00
Luis Pater
f43d25def1 Merge pull request #496 from kunish/fix/copilot-premium-request-inflation
fix(copilot): prevent intermittent context overflow for Claude models
2026-04-08 23:43:15 +08:00
Luis Pater
a279192881 Merge pull request #498 from router-for-me/plus
v6.9.17
2026-04-08 23:42:40 +08:00
Luis Pater
6a43d7285c Merge branch 'main' into plus 2026-04-08 23:42:05 +08:00
kunish
578c312660 fix(copilot): lower static Claude context limits and expose them to Claude Code
The Copilot API enforces per-account prompt token limits (128K individual,
168K business) that are lower than the total context window (200K). When
the dynamic /models API fetch fails or returns no capabilities.limits,
the static fallback of 200K exceeds the real enforced limit, causing
intermittent "prompt token count exceeds the limit" errors.

Two complementary fixes:

1. Lower static Copilot Claude model ContextLength from 200000 to 128000
   (the conservative default matching defaultCopilotContextLength). Dynamic
   API limits override this when available.

2. Add context_length and max_completion_tokens to Claude-format model
   responses so Claude Code CLI can learn the actual Copilot limit instead
   of relying on its built-in 1M context configuration.
2026-04-08 17:02:53 +08:00
Supra4E8C
6bb9bf3132 Merge pull request #495 from Ve-ria/main
feat(codebuddy): 新增 glm-5v-turbo 模型并更新上下文长度
2026-04-08 14:27:43 +08:00
hkfires
343a2fc2f7 docs: update AGENTS.md for improved clarity and detail in commands and architecture 2026-04-08 12:33:16 +08:00
Luis Pater
12b967118b Merge pull request #2592 from router-for-me/tests
fix(tests): update test cases
2026-04-08 11:57:15 +08:00
Luis Pater
70efd4e016 chore: add workflow to retarget main PRs to dev automatically 2026-04-08 10:35:49 +08:00
Luis Pater
f5aa68ecda chore: add workflow to prevent AGENTS.md modifications in pull requests 2026-04-08 10:12:51 +08:00
rensumo
9a5f142c33 feat(codebuddy): add glm-5v-turbo model and update context lengths 2026-04-08 09:48:25 +08:00
hkfires
d390b95b76 fix(tests): update test cases 2026-04-08 08:53:50 +08:00
Luis Pater
d1f6224b70 Merge pull request #2569 from LucasInsight/fix/record-zero-usage
fix: record zero usage
2026-04-08 08:13:11 +08:00
Luis Pater
fcc59d606d fix(translator): add unit tests to validate output_item.done fallback logic for Gemini and Claude 2026-04-08 03:54:15 +08:00
Luis Pater
91e7591955 fix(executor): add transient 429 resource exhausted handling with retry logic 2026-04-08 02:48:53 +08:00
Luis Pater
4607356333 Merge pull request #491 from Ve-ria/main
修复 CodeBuddy 不支持非流式请求的问题
2026-04-07 18:25:21 +08:00
Luis Pater
9a9ed99072 Merge pull request #494 from router-for-me/plus
v6.9.16
2026-04-07 18:23:51 +08:00
Luis Pater
5ae38584b8 Merge branch 'main' into plus 2026-04-07 18:23:31 +08:00
Luis Pater
c8b7e2b8d6 fix(executor): ensure empty stream completions use output_item.done as fallback
Fixed: #2583
2026-04-07 18:21:12 +08:00
Luis Pater
cad45ffa33 Merge pull request #2578 from LemonZuo/feat_socks5h
feat: support socks5h scheme for proxy settings
2026-04-07 09:57:18 +08:00
Luis Pater
6a27bceec0 Merge pull request #2576 from zilianpn/fix/disable-cooling-auth-errors
fix(auth): honor disable-cooling and enrich no-auth errors
2026-04-07 09:56:25 +08:00
Lemon
163d68318f feat: support socks5h scheme for proxy settings 2026-04-07 07:46:11 +08:00
zilianpn
0ea768011b fix(auth): honor disable-cooling and enrich no-auth errors 2026-04-07 01:12:13 +08:00
Michael
8b9dbe10f0 fix: record zero usage 2026-04-06 20:19:42 +08:00
rensumo
341b4beea1 Update internal/runtime/executor/codebuddy_executor.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-06 14:16:56 +08:00
rensumo
bea13f9724 fix(executor): support non-stream requests for CodeBuddy 2026-04-06 13:59:06 +08:00
Luis Pater
9f5bdfaa31 Merge pull request #2531 from jamestut/openai-vertex-token-usage-fix
Fix missing `response.completed.usage` for late-usage OpenAI-compatible streams
2026-04-06 09:30:49 +08:00
Luis Pater
9eabdd09db Merge pull request #2522 from aikins01/fix/strip-tool-use-signature
fix(amp): strip signature from tool_use blocks before forwarding to Claude
2026-04-06 09:30:14 +08:00
Luis Pater
c3f8dc362e Merge pull request #2491 from mpfo0106/feature/claude-code-safe-alignment-sentinels
test(claude): add compatibility sentinels and centralize builtin fallback handling
2026-04-06 09:27:08 +08:00
Luis Pater
b85120873b Merge pull request #2332 from RaviTharuma/fix/claude-thinking-signature
fix: preserve Claude thinking signatures in Codex translator
2026-04-06 09:25:06 +08:00
Luis Pater
6f58518c69 docs(readme): remove redundant GITSTORE_GIT_BRANCH description in README files 2026-04-06 09:23:04 +08:00
Luis Pater
000fcb15fa Merge pull request #2298 from snoyiatk/feat/add-gitstore-branch
feat(gitstore): add support for specifying git branch (via GITSTORE_G…
2026-04-06 09:21:03 +08:00
Luis Pater
ea43361492 Merge pull request #2121 from destinoantagonista-wq/main
Reconcile registry model states on auth changes
2026-04-06 09:13:27 +08:00
Luis Pater
c1818f197b Merge pull request #1940 from Blue-B/fix/claude-interleaved-thinking-amp-gzip-budget
fix(claude): enable interleaved-thinking beta, decode AMP error gzip, fix budget 400
2026-04-06 09:08:23 +08:00
Aikins Laryea
b0653cec7b fix(amp): strip signature from tool_use blocks before forwarding to Claude
ensureAmpSignature injects signature:"" into tool_use blocks so the
Amp TUI does not crash on P.signature.length. when Amp sends the
conversation back, Claude rejects the extra field with 400:
  tool_use.signature: Extra inputs are not permitted

strip the proxy-injected signature from tool_use blocks in
SanitizeAmpRequestBody before forwarding to the upstream API.
2026-04-05 12:26:24 +00:00
Luis Pater
22a1a24cf5 feat(executor): add tests for preserving key order in cache control functions
Added comprehensive tests to ensure key order is maintained when modifying payloads in `normalizeCacheControlTTL` and `enforceCacheControlLimit` functions. Removed unused helper functions and refactored implementations for better readability and efficiency.
2026-04-05 17:58:13 +08:00
Luis Pater
7223fee2de Merge branch 'pr-488'
# Conflicts:
#	README.md
#	README_CN.md
#	README_JA.md
2026-04-05 02:08:45 +08:00
Luis Pater
ada8e2905e feat(api): enhance proxy resolution for API key-based auth
Added comprehensive support for resolving proxy URLs from configuration based on API key and provider attributes. Introduced new helper functions and extended the test suite to validate fallback mechanisms and compatibility cases.
2026-04-05 01:56:34 +08:00
Luis Pater
4ba10531da feat(docs): add Poixe AI sponsorship details to README files
Added Poixe AI sponsorship information, including referral bonuses and platform capabilities, to README files in English, Japanese, and Chinese. Updated assets to include Poixe AI logo.
2026-04-05 01:20:50 +08:00
Luis Pater
3774b56e9f feat(misc): add background updater for Antigravity version caching
Introduce `StartAntigravityVersionUpdater` to periodically refresh the cached Antigravity version using a non-blocking background process. Updated main server flow to initialize the updater.
2026-04-04 22:09:11 +08:00
Luis Pater
c2d4137fb9 feat(executor): enhance Qwen system message handling with strict injection and merging rules
Closes: #2537
2026-04-04 21:51:02 +08:00
Luis Pater
2ee938acaf Merge pull request #2535 from rensumo/main
feat: 动态获取 Antigravity User-Agent 版本号
2026-04-04 21:00:47 +08:00
rensumo
8d5e470e1f feat: dynamically fetch antigravity UA version from releases API
Fetch the latest version from the antigravity auto-updater releases
endpoint and cache it for 6 hours. Falls back to 1.21.9 if the API
is unreachable or returns unexpected data.
2026-04-04 14:52:59 +08:00
James
65e9e892a4 Fix missing response.completed.usage for late-usage OpenAI-compatible streams 2026-04-04 05:58:04 +00:00
Luis Pater
3882494878 Merge pull request #486 from router-for-me/plus
v6.9.14
2026-04-04 11:40:13 +08:00
Luis Pater
088c1d07f4 Merge branch 'main' into plus 2026-04-04 11:40:03 +08:00
Luis Pater
8430b28cfa Merge pull request #2526 from rensumo/main
feat: 升级反重力 (antigravity) UA 版本为 1.21.9
2026-04-04 11:32:16 +08:00
rensumo
f3ab8f4bc5 chore: update antigravity UA version to 1.21.9 2026-04-04 07:35:08 +08:00
Luis Pater
0e4f189c2e Merge pull request #1302 from dinhkarate/feat(vertex)/add-prefix-field
Feat(vertex): add prefix field
2026-04-04 04:17:12 +08:00
Luis Pater
98509f615c Merge pull request #485 from kunish/fix/copilot-premium-request-inflation
fix(copilot): reduce premium request inflation, enable thinking, and use dynamic API limits
2026-04-04 02:19:56 +08:00
Luis Pater
e7a66ae504 Merge branch 'router-for-me:main' into main 2026-04-04 02:18:06 +08:00
Luis Pater
754b126944 fix(executor): remove commented-out code in QwenExecutor 2026-04-04 02:14:48 +08:00
Luis Pater
ae37ccffbf Merge pull request #2520 from Arronlong/main
fix:qwen invalid_parameter_error
2026-04-04 02:13:09 +08:00
Luis Pater
42c062bb5b Merge pull request #2509 from adamhelfgott/fix-claude-thinking-temperature
Normalize Claude temperature when thinking is enabled
2026-04-03 23:55:50 +08:00
kunish
87bf0b73d5 fix(copilot): use dynamic API limits to prevent prompt token overflow
The Copilot API enforces per-account prompt token limits (128K individual,
168K business) that differ from the static 200K context length advertised
by the proxy. This mismatch caused Claude Code to accumulate context
beyond the actual limit, triggering "prompt token count exceeds the limit
of 128000" errors.

Changes:
- Extract max_prompt_tokens and max_output_tokens from the Copilot
  /models API response (capabilities.limits) and use them as the
  authoritative ContextLength and MaxCompletionTokens values
- Add CopilotModelLimits struct and Limits() helper to parse limits
  from the existing Capabilities map
- Fix GitLab Duo context-1m beta header not being set when routing
  through the Anthropic gateway (gitlab_duo_force_context_1m attr
  was set but only gin headers were checked)
- Fix flaky parallel tests that shared global model registry state
2026-04-03 23:54:17 +08:00
Luis Pater
f389667ec3 Merge pull request #2513 from lonr-6/codex/fix-ws-custom-tool-repair-v2
fix: repair responses websocket custom tool call pairing
2026-04-03 23:45:38 +08:00
Arronlong
29dba0399b Comment out system message check in Qwen executor
fix qwen invalid_parameter_error
2026-04-03 23:07:33 +08:00
Luis Pater
a824e7cd0b feat(models): add GPT-5.3, GPT-5.4, and GPT-5.4-mini with enhanced "thinking" levels 2026-04-03 23:05:10 +08:00
Luis Pater
140faef7dc Merge branch 'router-for-me:main' into main 2026-04-03 21:48:23 +08:00
Luis Pater
adb580b344 feat(security): add configuration to toggle Gemini CLI endpoint access
Closes: #2445
2026-04-03 21:46:49 +08:00
Luis Pater
06405f2129 fix(security): enforce stricter localhost validation for GeminiCLIAPIHandler
Closes: #2445
2026-04-03 21:22:03 +08:00
kunish
b849bf79d6 fix(copilot): address code review — SSE reasoning, multi-choice, agent detection
- Strip SSE `data:` prefix before normalizing reasoning_text→reasoning_content
  in streaming mode; re-wrap afterward for the translator
- Iterate all choices in normalizeGitHubCopilotReasoningField (not just
  choices[0]) to support n>1 requests
- Remove over-broad tool-role fallback in isAgentInitiated that scanned
  all messages for role:"tool", aligning with opencode's approach of only
  detecting active tool loops — genuine user follow-ups after tool use are
  no longer mis-classified as agent-initiated
- Add 5 reasoning normalization tests; update 2 X-Initiator tests to match
  refined semantics
2026-04-03 20:51:19 +08:00
kunish
59af2c57b1 fix(copilot): reduce premium request inflation and enable thinking
This commit addresses three issues with Claude Code through GitHub
Copilot:

1. **Premium request inflation**: Responses API requests were missing
   Openai-Intent headers and proper defaults, causing Copilot to bill
   each tool-loop continuation as a new premium request. Fixed by adding
   isAgentInitiated() heuristic (checks for tool_result content or
   preceding assistant tool_use), applying Responses API defaults
   (store, include, reasoning.summary), and local tiktoken-based token
   counting to avoid extra API calls.

2. **Context overflow**: Claude Code's modelSupports1M() hardcodes
   opus-4-6 as 1M-capable, but Copilot only supports ~128K-200K.
   Fixed by stripping the context-1m-2025-08-07 beta from translated
   request bodies. Also forwards response headers in non-streaming
   Execute() and registers the GET /copilot-quota management API route.

3. **Thinking not working**: Add ThinkingSupport with level-based
   reasoning to Claude models in the static definitions. Normalize
   Copilot's non-standard 'reasoning_text' response field to
   'reasoning_content' before passing to the SDK translator. Use
   caller-provided context in CountTokens instead of Background().
2026-04-03 20:24:30 +08:00
Kai Wang
d1fd2c4ad4 fix: repair websocket custom tool calls 2026-04-03 17:11:44 +08:00
Kai Wang
b6c6379bfa fix: repair websocket custom tool calls 2026-04-03 17:11:42 +08:00
Kai Wang
8f0e66b72e fix: repair websocket custom tool calls 2026-04-03 17:11:41 +08:00
Adam Helfgott
f63cf6ff7a Normalize Claude temperature for thinking 2026-04-03 03:45:51 -04:00
Luis Pater
d2419ed49d feat(executor): ensure default system message in QwenExecutor payload 2026-04-03 11:18:48 +08:00
Luis Pater
516d22c695 Merge pull request #484 from Ve-ria/main
更新CodeBuddy CN的模型列表
2026-04-03 11:10:32 +08:00
rensumo
73cda6e836 Update CodeBuddy DeepSeek model description 2026-04-03 11:03:33 +08:00
rensumo
0805989ee5 更新CodeBuddy CN的模型列表 2026-04-03 10:59:27 +08:00
mpfo0106
9b5ce8c64f Keep Claude builtin helpers aligned with the shared helper layout
The review asked for the builtin tool registry helper to live with the rest
of executor support utilities. This moves the registry code into the helps
package, exports the minimal surface executor needs, and keeps behavior tests
with the executor while leaving registry-focused checks with the helper.

Constraint: Requested layout keeps executor helper utilities centralized under internal/runtime/executor/helps
Rejected: Keep the files in executor and reply with rationale | conflicts with requested package organization
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Keep executor behavior tests near applyClaudeToolPrefix and keep pure registry tests in helps
Tested: go test ./internal/runtime/executor/helps ./internal/runtime/executor -run 'Claude|Builtin|Tool'; go test ./test/...; go test ./...
Not-tested: End-to-end Claude Code direct-connect/session runtime behavior
2026-04-03 00:13:02 +09:00
Duong M. CUONG
058793c73a feat(gitstore): honor configured branch and follow live remote default 2026-04-02 14:44:44 +00:00
Luis Pater
75da02af55 Merge branch 'router-for-me:main' into main 2026-04-02 22:34:47 +08:00
Luis Pater
ab9ebea592 Merge PR #2474
# Conflicts:
#	internal/api/modules/amp/response_rewriter.go
#	internal/api/modules/amp/response_rewriter_test.go
2026-04-02 22:31:12 +08:00
Luis Pater
7ee37ee4b9 feat: add /healthz endpoint and test coverage for health check
Closes: #2493
2026-04-02 21:56:27 +08:00
Luis Pater
837afffb31 docs: remove README_JA.md and clean up related links from README files 2026-04-02 21:37:47 +08:00
Luis Pater
03a1bac898 Merge upstream v6.9.9 (PR #483) 2026-04-02 21:31:21 +08:00
Luis Pater
3171d524f0 docs: fix duplicated ProxyPal entry in README files 2026-04-02 21:22:40 +08:00
Luis Pater
3e78a8d500 Merge branch 'main' into dev 2026-04-02 21:21:26 +08:00
Luis Pater
fcba912cc4 Merge pull request #2492 from davidwushi1145/main
fix(responses): reassemble split SSE event/data frames before streaming
2026-04-02 21:20:31 +08:00
Luis Pater
7170eeea5f Merge pull request #2454 from buddingnewinsights/add-proxypal-to-readme
docs: add ProxyPal to "Who is with us?" section
2026-04-02 21:18:22 +08:00
Luis Pater
e3eb048c7a Merge pull request #2489 from Soein/upstream-pr
fix: 增强 Claude 反代检测对抗能力
2026-04-02 21:16:58 +08:00
Luis Pater
a59e92435b Merge pull request #2490 from router-for-me/logs
Refactor websocket logging and error handling
2026-04-02 20:47:31 +08:00
davidwushi1145
108895fc04 Harden Responses SSE framing against partial chunk boundaries
Follow-up review found two real framing hazards in the handler-layer
framer: it could flush a partial `data:` payload before the JSON was
complete, and it could inject an extra newline before chunks that
already began with `\n`/`\r\n`. This commit tightens the framer so it
only emits undelimited events when the buffered `data:` payload is
already valid JSON (or `[DONE]`), skips newline injection for chunks
that already start with a line break, and avoids the heavier
`bytes.Split` path while scanning SSE fields.

The regression suite now covers split `data:` payload chunks,
newline-prefixed chunks, and dropping incomplete trailing data on
flush, so the original Responses fix remains intact while the review
concerns are explicitly locked down.

Constraint: Keep the follow-up limited to handler-layer framing and tests
Rejected: Ignore the review and rely on current executor chunk shapes | leaves partial data payload corruption possible
Rejected: Build a fully generic SSE parser | wider change than needed for the identified risks
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Do not emit undelimited Responses SSE events unless buffered `data:` content is already complete and valid
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers/openai -count=1
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers -count=1
Tested: /tmp/go1.26.1/go/bin/go vet ./sdk/api/handlers/...
Not-tested: Full repository test suite outside sdk/api/handlers packages
2026-04-02 20:39:49 +08:00
davidwushi1145
abc293c642 Prevent malformed Responses SSE frames from breaking stream clients
Line-oriented upstream executors can emit `event:` and `data:` as
separate chunks, but the Responses handler had started terminating
each incoming chunk as a full SSE event. That split `response.created`
into an empty event plus a later data block, which broke downstream
clients like OpenClaw.

This keeps the fix in the handler layer: a small stateful framer now
buffers standalone `event:` lines until the matching `data:` arrives,
preserves already-framed events, and ignores delimiter-only leftovers.
The regression suite now covers split event/data framing, full-event
passthrough, terminal errors, and the bootstrap path that forwards
line-oriented openai-response streams from non-Codex executors too.

Constraint: Keep the fix localized to Responses handler framing instead of patching every executor
Rejected: Revert to v6.9.7 chunk writing | would reintroduce data-only framing regressions
Rejected: Patch each line-oriented executor separately | duplicates fragile SSE assembly logic
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Do not assume incoming Responses stream chunks are already complete SSE events; preserve handler-layer reassembly for split `event:`/`data:` inputs
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers/openai -count=1
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers -count=1
Tested: /tmp/go1.26.1/go test ./sdk/api/handlers/... -count=1
Tested: /tmp/go1.26.1/go/bin/go vet ./sdk/api/handlers/...
Tested: Temporary patched server on 127.0.0.1:18317 -> /v1/models 200, /v1/responses non-stream 200, /v1/responses stream emitted combined `event:` + `data:` frames
Not-tested: Full repository test suite outside sdk/api/handlers packages
2026-04-02 20:26:42 +08:00
mpfo0106
da3a498a28 Keep Claude Code compatibility work low-risk and reviewable
This change stops short of broader Claude Code runtime alignment and instead
hardens two safe edges: builtin tool prefix handling and source-informed
sentinel coverage for future drift checks.

Constraint: Must preserve existing default behavior for current users
Rejected: Implement control-plane/session alignment now | too much runtime risk for a first slice
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Treat the new fixtures as compatibility sentinels, not a full Claude Code schema contract
Tested: go test ./test/...; go test ./sdk/translator/...; go test ./internal/runtime/executor -run 'Claude|Builtin|Tool'; go test ./...
Not-tested: End-to-end Claude Code direct-connect/session runtime behavior
2026-04-02 20:35:39 +09:00
pzy
bb44671845 fix: 修复反代检测对抗的 3 个问题
- computeFingerprint 使用 rune 索引替代字节索引,修复多字节字符指纹不匹配
- utls Chrome TLS 指纹仅对 Anthropic 官方域名生效,自定义 base_url 走标准 transport
- IPv6 地址使用 net.JoinHostPort 正确拼接端口
2026-04-02 19:12:55 +08:00
Luis Pater
09e480036a feat(auth): add support for managing custom headers in auth files
Closes #2457
2026-04-02 19:11:09 +08:00
pzy
249f969110 fix: Claude API 请求使用 utls Chrome TLS 指纹
Claude executor 的 API 请求之前使用 Go 标准库 crypto/tls,JA3 指纹
与真实 Claude Code(Bun/BoringSSL)不匹配,可被 Cloudflare 识别。

- 新增 helps/utls_client.go,封装 utls Chrome 指纹 + HTTP/2 + 代理支持
- Claude executor 的 4 处 NewProxyAwareHTTPClient 替换为 NewUtlsHTTPClient
- 其他 executor(Gemini/Codex/iFlow 等)不受影响,仍用标准 TLS
- 非 HTTPS 请求自动回退到标准 transport
2026-04-02 19:09:56 +08:00
hkfires
4f8acec2d8 refactor(logging): centralize websocket handshake recording 2026-04-02 18:39:32 +08:00
hkfires
34339f61ee Refactor websocket logging and error handling
- Introduced new logging functions for websocket requests, handshakes, errors, and responses in `logging_helpers.go`.
- Updated `CodexWebsocketsExecutor` to utilize the new logging functions for improved clarity and consistency in websocket operations.
- Modified the handling of websocket upgrade rejections to log relevant metadata.
- Changed the request body key to a timeline body key in `openai_responses_websocket.go` to better reflect its purpose.
- Enhanced tests to verify the correct logging of websocket events and responses, including disconnect events and error handling scenarios.
2026-04-02 17:30:51 +08:00
pzy
4045378cb4 fix: 增强 Claude 反代检测对抗能力
基于 Claude Code v2.1.88 源码分析,修复多个可被 Anthropic 检测的差距:

- 实现消息指纹算法(SHA256 盐值 + 字符索引),替代随机 buildHash
- billing header cc_version 从设备 profile 动态取版本号,不再硬编码
- billing header cc_entrypoint 从客户端 UA 解析,支持 cli/vscode/local-agent
- billing header 新增 cc_workload 支持(通过 X-CPA-Claude-Workload 头传入)
- 新增 X-Claude-Code-Session-Id 头(每 apiKey 缓存 UUID,TTL=1h)
- 新增 x-client-request-id 头(仅 api.anthropic.com,每请求 UUID)
- 补全 4 个缺失的 beta flags(structured-outputs/fast-mode/redact-thinking/token-efficient-tools)
- OAuth scope 对齐 Claude Code 2.1.88(移除 org:create_api_key,添加 sessions/mcp/file_upload)
- Anthropic-Dangerous-Direct-Browser-Access 仅在 API key 模式发送
- 响应头网关指纹清洗(剥离 litellm/helicone/portkey/cloudflare/kong/braintrust 前缀头)
2026-04-02 15:55:22 +08:00
Luis Pater
2df35449fe Fix executor compat helpers 2026-04-02 12:20:12 +08:00
Luis Pater
c744179645 Merge PR #479 2026-04-02 12:15:33 +08:00
Luis Pater
9720b03a6b Merge pull request #477 from ben-vargas/plus-main
fix(copilot): route Gemini preview models to chat endpoint and correct context lengths
2026-04-02 11:36:51 +08:00
Luis Pater
f2c0f3d325 Merge pull request #476 from hungthai1401/fix/ghc-gpt54mini
Fix GitHub Copilot gpt-5.4-mini endpoint routing
2026-04-02 11:36:26 +08:00
Luis Pater
4f99bc54f1 test: update codex header expectations 2026-04-02 11:19:37 +08:00
Luis Pater
913f4a9c5f test: fix executor tests after helpers refactor 2026-04-02 11:12:30 +08:00
Luis Pater
25d1c18a3f fix: scope experimental cch signing to billing header 2026-04-02 11:03:11 +08:00
Luis Pater
d09dd4d0b2 Merge commit '15c2f274ea690c9a7c9db22f9f454af869db5375' into dev 2026-04-02 10:59:54 +08:00
Luis Pater
474fb042da Merge pull request #2476 from router-for-me/cherry-pick/pr-2438-to-dev
Cherry-pick PR #2438 onto dev
2026-04-02 10:36:50 +08:00
Michael
8435c3d7be feat(tui): show time in usage details 2026-04-02 10:35:13 +08:00
Luis Pater
e783d0a62e Merge pull request #2441 from MonsterQiu/issue-2421-alias-before-suspension
fix(auth): resolve oauth aliases before suspension checks
2026-04-02 10:27:39 +08:00
Luis Pater
b05f575e9b Merge pull request #2444 from 0oAstro/fix/codex-nonstream-finish-reason-tool-calls
fix(codex): set finish_reason to "tool_calls" in non-streaming response when tool calls are present
2026-04-02 10:01:25 +08:00
Aikins Laryea
f5e9f01811 test(amp): update tests to expect thinking blocks to pass through during streaming 2026-04-01 20:35:23 +00:00
Aikins Laryea
ff7dbb5867 test(amp): update tests to expect thinking blocks to pass through during streaming 2026-04-01 20:21:39 +00:00
Aikins Laryea
e34b2b4f1d fix(gemini): clean tool schemas and eager_input_streaming
delegate schema sanitization to util.CleanJSONSchemaForGemini and drop the top-level eager_input_streaming key to prevent validation errors when sending claude tools to the gemini api
2026-04-01 19:49:38 +00:00
edlsh
15c2f274ea fix: preserve cloak config defaults when mode omitted 2026-04-01 13:20:11 -04:00
edlsh
37249339ac feat: add opt-in experimental Claude cch signing 2026-04-01 13:03:17 -04:00
Luis Pater
c422d16beb Merge pull request #2398 from 7RPH/fix/responses-sse-framing
fix: preserve SSE event boundaries for Responses streams
2026-04-02 00:46:51 +08:00
Luis Pater
66cd50f603 Merge pull request #2468 from router-for-me/ip
fix(openai): improve client IP retrieval in websocket handler
2026-04-02 00:03:35 +08:00
hkfires
caa529c282 fix(openai): improve client IP retrieval in websocket handler 2026-04-01 20:16:01 +08:00
hkfires
51a4379bf4 refactor(openai): remove websocket body log truncation limit 2026-04-01 18:11:43 +08:00
Luis Pater
acf98ed10e fix(openai): add session reference counter and cache lifecycle management for websocket tools 2026-04-01 17:28:50 +08:00
Luis Pater
d1c07a091e fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency 2026-04-01 17:16:49 +08:00
Ben Vargas
c1a8adf1ab feat(registry): add GitHub Copilot gemini-3.1-pro-preview model 2026-04-01 01:25:03 -06:00
Ben Vargas
08e078fc25 fix(openai): route copilot Gemini preview models to chat endpoint 2026-04-01 01:24:58 -06:00
Luis Pater
105a21548f fix(codex): centralize session management with global store and add tests for executor session lifecycle 2026-04-01 13:17:10 +08:00
Luis Pater
1734aa1664 fix(codex): prioritize websocket-enabled credentials across priority tiers in scheduler logic 2026-04-01 12:51:12 +08:00
Luis Pater
ca11b236a7 refactor(runtime, openai): simplify header management and remove redundant websocket logging logic 2026-04-01 11:57:31 +08:00
huynhgiabuu
6fdff8227d docs: add ProxyPal to 'Who is with us?' section
Add ProxyPal (https://github.com/buddingnewinsights/proxypal) to the
community projects list in all three README files (EN, CN, JA).
Placed after CCS, restoring its original position.

ProxyPal is a cross-platform desktop app (macOS, Windows, Linux) that
wraps CLIProxyAPI with a native GUI, supporting multiple AI providers,
usage analytics, request monitoring, and auto-configuration for popular
coding tools.

Closes #2420
2026-04-01 10:23:22 +07:00
Luis Pater
330e12d3c2 fix(codex): conditionally set Session_id header for Mac OS user agents and clean up redundant logic 2026-04-01 11:11:45 +08:00
Thai Nguyen Hung
bd09c0bf09 feat(registry): add gpt-5.4-mini model to GitHub Copilot registry 2026-04-01 10:04:38 +07:00
Luis Pater
b468ca79c3 Merge branch 'dev' of github.com:router-for-me/CLIProxyAPI into dev 2026-04-01 03:09:03 +08:00
Luis Pater
d2c7e4e96a refactor(runtime): move executor utilities to helps package and update references 2026-04-01 03:08:20 +08:00
Luis Pater
1c7003ff68 Merge pull request #2452 from Lucaszmv/fix-qwen-cli-v0.13.2
fix(qwen): update CLI simulation to v0.13.2 and adjust header casing
2026-04-01 02:44:27 +08:00
Lucaszmv
1b44364e78 fix(qwen): update CLI simulation to v0.13.2 2026-03-31 15:19:07 -03:00
0oAstro
ec77f4a4f5 fix(codex): set finish_reason to tool_calls in non-streaming response when tool calls are present 2026-03-31 14:12:15 +05:30
MonsterQiu
f611dd6e96 refactor(auth): dedupe route-aware model support checks 2026-03-31 15:42:25 +08:00
MonsterQiu
07b7c1a1e0 fix(auth): resolve oauth aliases before suspension checks 2026-03-31 14:27:14 +08:00
Luis Pater
51fd58d74f fix(codex): use normalizeCodexInstructions to set default instructions 2026-03-31 12:16:57 +08:00
Luis Pater
faae9c2f7c Merge pull request #2422 from MonsterQiu/fix/codex-compact-instructions
fix(codex): add default instructions for /responses/compact
2026-03-31 12:14:20 +08:00
Luis Pater
bc3a6e4646 Merge pull request #2434 from MonsterQiu/fix/codex-responses-null-instructions
fix(codex): normalize null instructions for /responses requests
2026-03-31 12:01:21 +08:00
Luis Pater
b09b03e35e Merge pull request #2424 from possible055/fix/websocket-transcript-replacement
fix(openai): handle transcript replacement after websocket v2 compaction
2026-03-31 11:00:33 +08:00
Luis Pater
16231947e7 Merge pull request #2426 from xixiwenxuanhe/feature/antigravity-credits
feat(antigravity): add AI credits quota fallback
2026-03-31 10:51:40 +08:00
MonsterQiu
39b9a38fbc fix(codex): normalize null instructions across responses paths 2026-03-31 10:32:39 +08:00
MonsterQiu
bd855abec9 fix(codex): normalize null instructions for responses requests 2026-03-31 10:29:02 +08:00
Luis Pater
7c3c2e9f64 Merge pull request #2417 from CharTyr/fix/amp-streaming-thinking-regression
fix(amp): 修复流式响应中 thinking block 被错误抑制导致的 TUI 空白回复
2026-03-31 10:12:13 +08:00
Luis Pater
c10f8ae2e2 Fixed: #2420
docs(readme): remove ProxyPal section from all README translations
2026-03-31 07:23:02 +08:00
xixiwenxuanhe
a0bf33eca6 fix(antigravity): preserve fallback and honor config gate 2026-03-31 00:14:05 +08:00
xixiwenxuanhe
88dd9c715d feat(antigravity): add AI credits quota fallback 2026-03-30 23:58:12 +08:00
apparition
a3e21df814 fix(openai): avoid developer transcript resets
- Narrow websocket transcript replacement detection to assistant outputs and function calls
- Preserve existing merge behavior for follow-up developer messages without previous_response_id
- Add a regression test covering mid-session developer message updates
2026-03-30 23:33:16 +08:00
MonsterQiu
d3b94c9241 fix(codex): normalize null instructions for compact requests 2026-03-30 22:58:05 +08:00
apparition
c1d7599829 fix(openai): handle transcript replacement after websocket compaction
- Add shouldReplaceWebsocketTranscript() to detect historical model output in input
- Add normalizeResponseTranscriptReplacement() for full transcript reset handling
- Prevent duplicate stale turn-state when clients replace local history post-compaction
- Avoid orphaned function_call items from incremental append on compact transcripts
- Add unit tests for transcript replacement detection and state reset behavior
2026-03-30 22:44:58 +08:00
MonsterQiu
d11936f292 fix(codex): add default instructions for /responses/compact 2026-03-30 22:44:46 +08:00
Luis Pater
17363edf25 fix(auth): skip downtime for request-scoped 404 errors in model state management 2026-03-30 22:22:42 +08:00
CharTyr
279cbbbb8a fix(amp): don't suppress thinking blocks in streaming mode
Reverts the streaming thinking suppression introduced in b15453c.
rewriteStreamEvent should only inject signatures and rewrite model
names — suppressing thinking blocks in streaming mode breaks SSE
index alignment and causes the Amp TUI to render empty responses
on the second message onward (especially with model-mapped
non-Claude providers like GPT-5.4).

Non-streaming responses still suppress thinking when tool_use is
present via rewriteModelInResponse.
2026-03-30 20:09:32 +08:00
Luis Pater
486cd4c343 Merge pull request #2409 from sususu98/fix/tool-use-pairing-break
fix(antigravity): reorder model parts to prevent tool_use↔tool_result pairing breakage
2026-03-30 16:59:46 +08:00
sususu98
25feceb783 fix(antigravity): reorder model parts to prevent tool_use↔tool_result pairing breakage
When a Claude assistant message contains [text, tool_use, text], the
Antigravity API internally splits the model message at functionCall
boundaries, creating an extra assistant turn between tool_use and the
following tool_result. Claude then rejects with:

  tool_use ids were found without tool_result blocks immediately after

Fix: extend the existing 2-way part reordering (thinking-first) to a
3-way partition: thinking → regular → functionCall. This ensures
functionCall parts are always last, so Antigravity's split cannot
insert an extra assistant turn before the user's tool_result.

Fixes #989
2026-03-30 15:09:33 +08:00
Luis Pater
d26752250d Merge pull request #2403 from CharTyr/clean-pr
fix(amp): 修复Amp CLI 集成 缺失/无效 signature 导致的 TUI 崩溃与上游 400 问题
2026-03-30 12:54:15 +08:00
CharTyr
b15453c369 fix(amp): address PR review - stream thinking suppression, SSE detection, test init
- Call suppressAmpThinking in rewriteStreamEvent for streaming path
- Handle nil return from suppressAmpThinking to skip suppressed events
- Narrow looksLikeSSEChunk to line-prefix detection (HasPrefix vs Contains)
- Initialize suppressedContentBlock map in test
2026-03-30 00:42:04 -04:00
CharTyr
04ba8c8bc3 feat(amp): sanitize signatures and handle stream suppression for Amp compatibility 2026-03-29 22:23:18 -04:00
Luis Pater
6570692291 Merge pull request #2400 from router-for-me/revert-2374-codex-cache-clean
Revert "fix(codex): restore prompt cache continuity for Codex requests"
2026-03-29 22:19:39 +08:00
trph
f73d55ddaa fix: simplify responses SSE suffix handling 2026-03-29 22:19:25 +08:00
Luis Pater
13aa5b3375 Revert "fix(codex): restore prompt cache continuity for Codex requests" 2026-03-29 22:18:14 +08:00
trph
0fcc02fbea fix: tighten responses SSE review follow-up 2026-03-29 22:10:28 +08:00
trph
c03883ccf0 fix: address responses SSE review feedback 2026-03-29 22:00:46 +08:00
trph
134a9eac9d fix: preserve SSE event boundaries for Responses streams 2026-03-29 17:23:16 +08:00
Luis Pater
6d8de0ade4 feat(auth): implement weighted provider rotation for improved scheduling fairness 2026-03-29 13:49:01 +08:00
Luis Pater
1587ff5e74 Merge pull request #2389 from router-for-me/claude
fix(claude): add default max_tokens for models
2026-03-29 13:03:20 +08:00
hkfires
f033d3a6df fix(claude): enhance ensureModelMaxTokens to use registered max_completion_tokens and fallback to default 2026-03-29 13:00:43 +08:00
hkfires
145e0e0b5d fix(claude): add default max_tokens for models 2026-03-29 12:46:00 +08:00
Luis Pater
f8d1bc06ea Merge pull request #469 from router-for-me/plus
v6.9.5
2026-03-29 12:40:26 +08:00
Luis Pater
d5930f4e44 Merge branch 'main' into plus 2026-03-29 12:40:17 +08:00
Luis Pater
9b7d7021af docs(readme): update LingtrueAPI link in all README translations 2026-03-29 12:30:24 +08:00
Luis Pater
e41c22ef44 docs(readme): add LingtrueAPI sponsorship details to all README translations 2026-03-29 12:23:37 +08:00
Ravi Tharuma
5fc2bd393e fix: retain codex thinking signature until item done 2026-03-28 14:41:25 +01:00
Luis Pater
55271403fb Merge pull request #2374 from VooDisss/codex-cache-clean
fix(codex): restore prompt cache continuity for Codex requests
2026-03-28 21:16:51 +08:00
Luis Pater
36fba66619 Merge pull request #2371 from RaviTharuma/docs/provider-specific-routes
docs: clarify provider-specific routing for aliased models
2026-03-28 21:11:29 +08:00
Ravi Tharuma
66eb12294a fix: clear stale thinking signature when no block is open 2026-03-28 14:08:31 +01:00
Ravi Tharuma
73b22ec29b fix: omit empty signature field from thinking blocks
Emit signature only when non-empty in both streaming content_block_start
and non-streaming thinking blocks. Avoids turning 'missing signature'
into 'empty/invalid signature' which Claude clients may reject.
2026-03-28 14:08:31 +01:00
Ravi Tharuma
c31ae2f3b5 fix: retain previously captured thinking signature on new summary part 2026-03-28 14:08:31 +01:00
Ravi Tharuma
76b53d6b5b fix: finalize pending thinking block before next summary part 2026-03-28 14:08:31 +01:00
Ravi Tharuma
a34dfed378 fix: preserve Claude thinking signatures in Codex translator 2026-03-28 14:08:31 +01:00
Luis Pater
b9b127a7ea Merge pull request #2347 from edlsh/fix/codex-strip-stream-options
fix(codex): strip stream_options from Responses API requests
2026-03-28 21:03:01 +08:00
Luis Pater
2741e7b7b3 Merge pull request #2346 from pjpjq/codex/fix-codex-capacity-retry
fix(codex): Treat Codex capacity errors as retryable
2026-03-28 21:00:50 +08:00
Luis Pater
1767a56d4f Merge pull request #2343 from kongkk233/fix/proxy-transport-defaults
Preserve default transport settings for proxy clients
2026-03-28 20:58:24 +08:00
Luis Pater
779e6c2d2f Merge pull request #2231 from 7RPH/fix/responses-stream-multi-tool-calls
fix: preserve separate streamed tool calls in Responses API
2026-03-28 20:53:19 +08:00
Luis Pater
73c831747b Merge pull request #2133 from DragonFSKY/fix/2061-stale-modelstates
fix(auth): prevent stale runtime state inheritance from disabled auth entries
2026-03-28 20:50:57 +08:00
Luis Pater
b8b89f34f4 Merge pull request #442 from LuxVTZ/feat/gitlab-duo-panel-parity
Improve GitLab Duo gateway compatibility

Restore internal/runtime/executor/claude_executor.go to main during merge.
2026-03-28 05:06:41 +08:00
VooDisss
e5d3541b5a refactor(codex): remove stale affinity cleanup leftovers
Drop the last affinity-related executor artifacts so the PR stays focused on the minimal Codex continuity fix set: stable prompt cache identity, stable session_id, and the executor-only behavior that was validated to restore cache reads.
2026-03-27 20:40:26 +02:00
VooDisss
79755e76ea refactor(pr): remove forbidden translator changes
Drop the chat-completions translator edits from this PR so the branch complies with the repository policy that forbids pull-request changes under internal/translator. The remaining PR stays focused on the executor-level Codex continuity fix that was validated to restore cache reuse.
2026-03-27 19:34:13 +02:00
VooDisss
35f158d526 refactor(pr): narrow Codex cache fix scope
Remove the experimental auth-affinity routing changes from this PR so it stays focused on the validated Codex continuity fix. This keeps the prompt-cache repair while avoiding unrelated routing-policy concerns such as provider/model affinity scope, lifecycle cleanup, and hard-pin fallback semantics.
2026-03-27 19:06:34 +02:00
VooDisss
6962e09dd9 fix(auth): scope affinity by provider
Keep sticky auth affinity limited to matching providers and stop persisting execution-session IDs as long-lived affinity keys so provider switching and normal streaming traffic do not create incorrect pins or stale affinity state.
2026-03-27 18:52:58 +02:00
VooDisss
4c4cbd44da fix(auth): avoid leaking or over-persisting affinity keys
Stop using one-shot idempotency keys as long-lived auth-affinity identifiers and remove raw affinity-key values from debug logs so sticky routing keeps its continuity benefits without creating avoidable memory growth or credential exposure risks.
2026-03-27 18:34:51 +02:00
VooDisss
26eca8b6ba fix(codex): preserve continuity and safe affinity fallback
Restore Claude continuity after the continuity refactor, keep auth-affinity keys out of upstream Codex session identifiers, and only persist affinity after successful execution so retries can still rotate to healthy credentials when the first auth fails.
2026-03-27 18:27:33 +02:00
VooDisss
62b17f40a1 refactor(codex): align continuity helpers with review feedback
Align websocket continuity resolution with the HTTP Codex path, make auth-affinity principal keys use a stable string representation, and extract small helpers that remove duplicated continuity and affinity logic without changing the validated cache-hit behavior.
2026-03-27 18:11:57 +02:00
VooDisss
511b8a992e fix(codex): restore prompt cache continuity for Codex requests
Prompt caching on Codex was not reliably reusable through the proxy because repeated chat-completions requests could reach the upstream without the same continuity envelope. In practice this showed up most clearly with OpenCode, where cache reads worked in the reference client but not through CLIProxyAPI, although the root cause is broader than OpenCode itself.

The proxy was breaking continuity in several ways: executor-layer Codex request preparation stripped prompt_cache_retention, chat-completions translation did not preserve that field, continuity headers used a different shape than the working client behavior, and OpenAI-style Codex requests could be sent without a stable prompt_cache_key. When that happened, session_id fell back to a fresh random value per request, so upstream Codex treated repeated requests as unrelated turns instead of as part of the same cacheable context.

This change fixes that by preserving caller-provided prompt_cache_retention on Codex execution paths, preserving prompt_cache_retention when translating OpenAI chat-completions requests to Codex, aligning Codex continuity headers to session_id, and introducing an explicit Codex continuity policy that derives a stable continuity key from the best available signal. The resolution order prefers an explicit prompt_cache_key, then execution session metadata, then an explicit idempotency key, then stable request-affinity metadata, then a stable client-principal hash, and finally a stable auth-ID hash when no better continuity signal exists.

The same continuity key is applied to both prompt_cache_key in the request body and session_id in the request headers so repeated requests reuse the same upstream cache/session identity. The auth manager also keeps auth selection sticky for repeated request sequences, preventing otherwise-equivalent Codex requests from drifting across different upstream auth contexts and accidentally breaking cache reuse.

To keep the implementation maintainable, the continuity resolution and diagnostics are centralized in a dedicated Codex continuity helper instead of being scattered across executor flow code. Regression coverage now verifies retention preservation, continuity-key precedence, stable auth-ID fallback, websocket parity, translator preservation, and auth-affinity behavior. Manual validation confirmed prompt cache reads now occur through CLIProxyAPI when using Codex via OpenCode, and the fix should also benefit other clients that rely on stable repeated Codex request continuity.
2026-03-27 17:49:29 +02:00
Ravi Tharuma
0ab977c236 docs: clarify provider path limitations 2026-03-27 11:13:08 +01:00
Ravi Tharuma
224f0de353 docs: neutralize provider-specific path wording 2026-03-27 11:11:06 +01:00
Ravi Tharuma
d54de441d3 docs: clarify provider-specific routing for aliased models 2026-03-27 10:53:09 +01:00
edlsh
754f3bcbc3 fix(codex): strip stream_options from Responses API requests
The Codex/OpenAI Responses API does not support the stream_options
parameter. When clients (e.g. Amp CLI) include stream_options in their
requests, CLIProxyAPI forwards it as-is, causing a 400 error:

  {"detail":"Unsupported parameter: stream_options"}

Strip stream_options alongside the other unsupported parameters
(previous_response_id, prompt_cache_retention, safety_identifier)
in Execute, ExecuteStream, and CountTokens.
2026-03-25 11:58:36 -04:00
pjpj
36973d4a6f Handle Codex capacity errors as retryable 2026-03-25 23:25:31 +08:00
kwz
c89d19b300 Preserve default transport settings for proxy clients 2026-03-25 15:33:09 +08:00
trph
cc32f5ff61 fix: unify Responses output indexes for streamed items 2026-03-24 08:59:09 +08:00
trph
fbff68b9e0 fix: preserve choice-aware output indexes for streamed tool calls 2026-03-24 08:54:43 +08:00
trph
7e1a543b79 fix: preserve separate streamed tool calls in Responses API 2026-03-24 08:51:15 +08:00
DragonFSKY
74b862d8b8 test(cliproxy): cover delete re-add stale state flow 2026-03-24 00:21:04 +08:00
dinhkarate
36efcc6e28 fix(vertex): include prefix in auth filename and validate at import
Address two blocking issues from PR review:
- Auth file now named vertex-{prefix}-{project}.json so importing the
  same project with different prefixes no longer overwrites credentials
- Prefix containing "/" is rejected at import time instead of being
  silently ignored at runtime
- Add prefix to in-memory metadata map for consistency

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 15:06:04 +07:00
Pham Quang Dinh
a337ecf35c Merge branch 'router-for-me:main' into feat(vertex)/add-prefix-field 2026-03-17 11:48:40 +07:00
DragonFSKY
5c817a9b42 fix(auth): prevent stale ModelStates inheritance from disabled auth entries
When an auth file is deleted and re-created with the same path/ID, the
new auth could inherit stale ModelStates (cooldown/backoff) from the
previously disabled entry, preventing it from being routed.

Gate runtime state inheritance (ModelStates, LastRefreshedAt,
NextRefreshAfter) on both existing and incoming auth being non-disabled
in Manager.Update and Service.applyCoreAuthAddOrUpdate.

Closes #2061
2026-03-14 23:46:23 +08:00
destinoantagonista-wq
e08f68ed7c chore(auth): drop reconcile test file from pr 2026-03-14 14:41:26 +00:00
destinoantagonista-wq
f09ed25fd3 fix(auth): tighten registry model reconciliation 2026-03-14 14:40:06 +00:00
luxvtz
5da0decef6 Improve GitLab Duo gateway compatibility 2026-03-14 03:18:43 -07:00
destinoantagonista-wq
e166e56249 Reconcile registry model states on auth changes
Add Manager.ReconcileRegistryModelStates to clear stale per-model runtime failures for models currently registered in the global model registry. The method finds models supported for an auth, resets non-clean ModelState entries, updates aggregated availability, persists changes, and pushes a snapshot to the scheduler. Introduce modelStateIsClean helper to determine when a model state needs resetting. Call ReconcileRegistryModelStates from Service paths that register/refresh models (applyCoreAuthAddOrUpdate and refreshModelRegistrationForAuth) to keep the scheduler and global registry aligned after model re-registration.
2026-03-13 19:41:49 +00:00
Blue-B
5f58248016 fix(claude): clamp max_tokens to model limit in normalizeClaudeBudget
When adjustedBudget < minBudget, the previous fix blindly set
max_tokens = budgetTokens+1 which could exceed MaxCompletionTokens.

Now: cap max_tokens at MaxCompletionTokens, recalculate budget, and
disable thinking entirely if constraints are unsatisfiable.

Add unit tests covering raise, clamp, disable, and no-op scenarios.
2026-03-09 22:10:30 +09:00
Blue-B
07d6689d87 fix(claude): add interleaved-thinking beta header, AMP gzip error decoding, normalizeClaudeBudget max_tokens
1. Always include interleaved-thinking-2025-05-14 beta header so that
   thinking blocks are returned correctly for all Claude models.

2. Remove status-code guard in AMP reverse proxy ModifyResponse so that
   error responses (4xx/5xx) with hidden gzip encoding are decoded
   properly — prevents garbled error messages reaching the client.

3. In normalizeClaudeBudget, when the adjusted budget falls below the
   model minimum, set max_tokens = budgetTokens+1 instead of leaving
   the request unchanged (which causes a 400 from the API).
2026-03-07 21:31:10 +09:00
dinhkarate
14cb2b95c6 feat(vertex): add --vertex-import-prefix flag for model namespacing 2026-01-29 13:32:38 +07:00
dinhkarate
fdeef48498 feat(vertex): Add Prefix field to VertexCredentialStorage for per-file model namespacing 2026-01-29 13:32:38 +07:00
142 changed files with 14381 additions and 2701 deletions

81
.github/workflows/agents-md-guard.yml vendored Normal file
View File

@@ -0,0 +1,81 @@
name: agents-md-guard
on:
pull_request_target:
types:
- opened
- synchronize
- reopened
permissions:
contents: read
issues: write
pull-requests: write
jobs:
close-when-agents-md-changed:
runs-on: ubuntu-latest
steps:
- name: Detect AGENTS.md changes and close PR
uses: actions/github-script@v7
with:
script: |
const prNumber = context.payload.pull_request.number;
const { owner, repo } = context.repo;
const files = await github.paginate(github.rest.pulls.listFiles, {
owner,
repo,
pull_number: prNumber,
per_page: 100,
});
const touchesAgentsMd = (path) =>
typeof path === "string" &&
(path === "AGENTS.md" || path.endsWith("/AGENTS.md"));
const touched = files.filter(
(f) => touchesAgentsMd(f.filename) || touchesAgentsMd(f.previous_filename),
);
if (touched.length === 0) {
core.info("No AGENTS.md changes detected.");
return;
}
const changedList = touched
.map((f) =>
f.previous_filename && f.previous_filename !== f.filename
? `- ${f.previous_filename} -> ${f.filename}`
: `- ${f.filename}`,
)
.join("\n");
const body = [
"This repository does not allow modifying `AGENTS.md` in pull requests.",
"",
"Detected changes:",
changedList,
"",
"Please revert these changes and open a new PR without touching `AGENTS.md`.",
].join("\n");
try {
await github.rest.issues.createComment({
owner,
repo,
issue_number: prNumber,
body,
});
} catch (error) {
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
}
await github.rest.pulls.update({
owner,
repo,
pull_number: prNumber,
state: "closed",
});
core.setFailed("PR modifies AGENTS.md");

View File

@@ -0,0 +1,73 @@
name: auto-retarget-main-pr-to-dev
on:
pull_request_target:
types:
- opened
- reopened
- edited
branches:
- main
permissions:
contents: read
issues: write
pull-requests: write
jobs:
retarget:
if: github.actor != 'github-actions[bot]'
runs-on: ubuntu-latest
steps:
- name: Retarget PR base to dev
uses: actions/github-script@v7
with:
script: |
const pr = context.payload.pull_request;
const prNumber = pr.number;
const { owner, repo } = context.repo;
const baseRef = pr.base?.ref;
const headRef = pr.head?.ref;
const desiredBase = "dev";
if (baseRef !== "main") {
core.info(`PR #${prNumber} base is ${baseRef}; nothing to do.`);
return;
}
if (headRef === desiredBase) {
core.info(`PR #${prNumber} is ${desiredBase} -> main; skipping retarget.`);
return;
}
core.info(`Retargeting PR #${prNumber} base from ${baseRef} to ${desiredBase}.`);
try {
await github.rest.pulls.update({
owner,
repo,
pull_number: prNumber,
base: desiredBase,
});
} catch (error) {
core.setFailed(`Failed to retarget PR #${prNumber} to ${desiredBase}: ${error.message}`);
return;
}
const body = [
`This pull request targeted \`${baseRef}\`.`,
"",
`The base branch has been automatically changed to \`${desiredBase}\`.`,
].join("\n");
try {
await github.rest.issues.createComment({
owner,
repo,
issue_number: prNumber,
body,
});
} catch (error) {
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
}

9
.gitignore vendored
View File

@@ -37,15 +37,16 @@ GEMINI.md
# Tooling metadata # Tooling metadata
.vscode/* .vscode/*
.worktrees/
.codex/* .codex/*
.claude/* .claude/*
.gemini/* .gemini/*
.serena/* .serena/*
.agent/* .agent/*
.agents/* .agents/*
.agents/*
.opencode/* .opencode/*
.idea/* .idea/*
.beads/*
.bmad/* .bmad/*
_bmad/* _bmad/*
_bmad-output/* _bmad-output/*
@@ -54,4 +55,10 @@ _bmad-output/*
# macOS # macOS
.DS_Store .DS_Store
._* ._*
# Opencode
.beads/
.opencode/
.cli-proxy-api/
.venv/
*.bak *.bak

58
AGENTS.md Normal file
View File

@@ -0,0 +1,58 @@
# AGENTS.md
Go 1.26+ proxy server providing OpenAI/Gemini/Claude/Codex compatible APIs with OAuth and round-robin load balancing.
## Repository
- GitHub: https://github.com/router-for-me/CLIProxyAPI
## Commands
```bash
gofmt -w . # Format (required after Go changes)
go build -o cli-proxy-api ./cmd/server # Build
go run ./cmd/server # Run dev server
go test ./... # Run all tests
go test -v -run TestName ./path/to/pkg # Run single test
go build -o test-output ./cmd/server && rm test-output # Verify compile (REQUIRED after changes)
```
- Common flags: `--config <path>`, `--tui`, `--standalone`, `--local-model`, `--no-browser`, `--oauth-callback-port <port>`
## Config
- Default config: `config.yaml` (template: `config.example.yaml`)
- `.env` is auto-loaded from the working directory
- Auth material defaults under `auths/`
- Storage backends: file-based default; optional Postgres/git/object store (`PGSTORE_*`, `GITSTORE_*`, `OBJECTSTORE_*`)
## Architecture
- `cmd/server/` — Server entrypoint
- `internal/api/` — Gin HTTP API (routes, middleware, modules)
- `internal/api/modules/amp/` — Amp integration (Amp-style routes + reverse proxy)
- `internal/thinking/` — Main thinking/reasoning pipeline. `ApplyThinking()` (apply.go) parses suffixes (`suffix.go`, suffix overrides body), normalizes config to canonical `ThinkingConfig` (`types.go`), normalizes and validates centrally (`validate.go`/`convert.go`), then applies provider-specific output via `ProviderApplier`. Do not break this "canonical representation → per-provider translation" architecture.
- `internal/runtime/executor/` — Per-provider runtime executors (incl. Codex WebSocket)
- `internal/translator/` — Provider protocol translators (and shared `common`)
- `internal/registry/` — Model registry + remote updater (`StartModelsUpdater`); `--local-model` disables remote updates
- `internal/store/` — Storage implementations and secret resolution
- `internal/managementasset/` — Config snapshots and management assets
- `internal/cache/` — Request signature caching
- `internal/watcher/` — Config hot-reload and watchers
- `internal/wsrelay/` — WebSocket relay sessions
- `internal/usage/` — Usage and token accounting
- `internal/tui/` — Bubbletea terminal UI (`--tui`, `--standalone`)
- `sdk/cliproxy/` — Embeddable SDK entry (service/builder/watchers/pipeline)
- `test/` — Cross-module integration tests
## Code Conventions
- Keep changes small and simple (KISS)
- Comments in English only
- If editing code that already contains non-English comments, translate them to English (don't add new non-English comments)
- For user-visible strings, keep the existing language used in that file/area
- New Markdown docs should be in English unless the file is explicitly language-specific (e.g. `README_CN.md`)
- As a rule, do not make standalone changes to `internal/translator/`. You may modify it only as part of broader changes elsewhere.
- If a task requires changing only `internal/translator/`, run `gh repo view --json viewerPermission -q .viewerPermission` to confirm you have `WRITE`, `MAINTAIN`, or `ADMIN`. If you do, you may proceed; otherwise, file a GitHub issue including the goal, rationale, and the intended implementation code, then stop further work.
- `internal/runtime/executor/` should contain executors and their unit tests only. Place any helper/supporting files under `internal/runtime/executor/helps/`.
- Follow `gofmt`; keep imports goimports-style; wrap errors with context where helpful
- Do not use `log.Fatal`/`log.Fatalf` (terminates the process); prefer returning errors and logging via logrus
- Shadowed variables: use method suffix (`errStart := server.Start()`)
- Wrap defer errors: `defer func() { if err := f.Close(); err != nil { log.Errorf(...) } }()`
- Use logrus structured logging; avoid leaking secrets/tokens in logs
- Avoid panics in HTTP handlers; prefer logged errors and meaningful HTTP status codes
- Timeouts are allowed only during credential acquisition; after an upstream connection is established, do not set timeouts for any subsequent network behavior. Intentional exceptions that must remain allowed are the Codex websocket liveness deadlines in `internal/runtime/executor/codex_websockets_executor.go`, the wsrelay session deadlines in `internal/wsrelay/session.go`, the management APICall timeout in `internal/api/handlers/management/api_tools.go`, and the `cmd/fetch_antigravity_models` utility timeouts

View File

@@ -1,6 +1,6 @@
# CLIProxyAPI Plus # CLIProxyAPI Plus
[English](README.md) | 中文 | [日本語](README_JA.md) [English](README.md) | 中文
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。 这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。

View File

@@ -1,187 +0,0 @@
# CLI Proxy API
[English](README.md) | [中文](README_CN.md) | 日本語
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
OAuth経由でOpenAI CodexGPTモデルおよびClaude Codeもサポートしています。
ローカルまたはマルチアカウントのCLIアクセスを、OpenAIResponses含む/Gemini/Claude互換のクライアントやSDKで利用できます。
## スポンサー
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7およびGLM-5はProユーザーのみ利用可能モデルを10以上の人気AIコーディングツールClaude Code、Cline、Roo Codeなどで利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
GLM CODING PLANを10%割引で取得https://z.ai/subscribe?ic=8JVLJQFSKB
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>PackyCodeのスポンサーシップに感謝しますPackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝しますAICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引がありますCLIProxyAPIユーザー向けの特別特典<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたしますBmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割90% OFF</b> という驚異的な価格でご利用いただけます!</td>
</tr>
</tbody>
</table>
## 概要
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
- OAuthログインによるOpenAI CodexサポートGPTモデル
- OAuthログインによるClaude Codeサポート
- OAuthログインによるQwen Codeサポート
- OAuthログインによるiFlowサポート
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
- ストリーミングおよび非ストリーミングレスポンス
- 関数呼び出し/ツールのサポート
- マルチモーダル入力サポート(テキストと画像)
- ラウンドロビン負荷分散による複数アカウント対応Gemini、OpenAI、Claude、QwenおよびiFlow
- シンプルなCLI認証フローGemini、OpenAI、Claude、QwenおよびiFlow
- Generative Language APIキーのサポート
- AI Studioビルドのマルチアカウント負荷分散
- Gemini CLIのマルチアカウント負荷分散
- Claude Codeのマルチアカウント負荷分散
- Qwen Codeのマルチアカウント負荷分散
- iFlowのマルチアカウント負荷分散
- OpenAI Codexのマルチアカウント負荷分散
- 設定によるOpenAI互換アップストリームプロバイダーOpenRouter
- プロキシ埋め込み用の再利用可能なGo SDK`docs/sdk-usage.md`を参照)
## はじめに
CLIProxyAPIガイド[https://help.router-for.me/](https://help.router-for.me/)
## 管理API
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
## Amp CLIサポート
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます
- Ampの APIパターン用のプロバイダールートエイリアス`/api/provider/{provider}/v1...`
- OAuth認証およびアカウント機能用の管理プロキシ
- 自動ルーティングによるスマートモデルフォールバック
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5``claude-sonnet-4`
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
## SDKドキュメント
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
- カスタムプロバイダーの例:`examples/custom-provider`
## コントリビューション
コントリビューションを歓迎しますお気軽にPull Requestを送ってください。
1. リポジトリをフォーク
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`
3. 変更をコミット(`git commit -m 'Add some amazing feature'`
4. ブランチにプッシュ(`git push origin feature/amazing-feature`
5. Pull Requestを作成
## 関連プロジェクト
CLIProxyAPIをベースにした以下のプロジェクトがあります
### [vibeproxy](https://github.com/automazeio/vibeproxy)
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデルGemini、Codex、Antigravityを即座に切り替えるCLIラッパー - APIキー不要
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
CLIProxyAPI管理用のmacOSネイティブGUIOAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
### [Quotio](https://github.com/nguyenphutrong/quotio)
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
### [CodMate](https://github.com/loocor/CodMate)
CLI AIセッションCodex、Claude Code、Gemini CLIを管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替えMain / Plus、自動ダウンロードおよび自動更新に対応
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LANローカルエリアネットワークを介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
> [!NOTE]
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
## その他の選択肢
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです
### [9Router](https://github.com/decolua/9router)
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換OpenAI/Claude/Gemini/Ollama、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツールCursor、Claude Code、Cline、RooCodeのサポートをゼロから構築 - APIキー不要
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイですスマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
> [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
## ライセンス
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。

BIN
assets/lingtrue.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

View File

@@ -26,6 +26,7 @@ import (
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
@@ -188,7 +189,7 @@ func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
httpReq.Close = true httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64") httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent())
httpClient := &http.Client{Timeout: 30 * time.Second} httpClient := &http.Client{Timeout: 30 * time.Second}
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil { if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {

View File

@@ -99,6 +99,7 @@ func main() {
var codeBuddyLogin bool var codeBuddyLogin bool
var projectID string var projectID string
var vertexImport string var vertexImport string
var vertexImportPrefix string
var configPath string var configPath string
var password string var password string
var tuiMode bool var tuiMode bool
@@ -139,6 +140,7 @@ func main() {
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
flag.StringVar(&password, "password", "", "") flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
@@ -188,6 +190,7 @@ func main() {
gitStoreRemoteURL string gitStoreRemoteURL string
gitStoreUser string gitStoreUser string
gitStorePassword string gitStorePassword string
gitStoreBranch string
gitStoreLocalPath string gitStoreLocalPath string
gitStoreInst *store.GitTokenStore gitStoreInst *store.GitTokenStore
gitStoreRoot string gitStoreRoot string
@@ -257,6 +260,9 @@ func main() {
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok { if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
gitStoreLocalPath = value gitStoreLocalPath = value
} }
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
gitStoreBranch = value
}
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok { if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
useObjectStore = true useObjectStore = true
objectStoreEndpoint = value objectStoreEndpoint = value
@@ -391,7 +397,7 @@ func main() {
} }
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore") gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
authDir := filepath.Join(gitStoreRoot, "auths") authDir := filepath.Join(gitStoreRoot, "auths")
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword) gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
gitStoreInst.SetBaseDir(authDir) gitStoreInst.SetBaseDir(authDir)
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil { if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
log.Errorf("failed to prepare git token store: %v", errRepo) log.Errorf("failed to prepare git token store: %v", errRepo)
@@ -510,7 +516,7 @@ func main() {
if vertexImport != "" { if vertexImport != "" {
// Handle Vertex service account import // Handle Vertex service account import
cmd.DoVertexImport(cfg, vertexImport) cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix)
} else if login { } else if login {
// Handle Google/Gemini login // Handle Google/Gemini login
cmd.DoLogin(cfg, projectID, options) cmd.DoLogin(cfg, projectID, options)
@@ -596,6 +602,7 @@ func main() {
if standalone { if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it. // Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath) managementasset.StartAutoUpdater(context.Background(), configFilePath)
misc.StartAntigravityVersionUpdater(context.Background())
if !localModel { if !localModel {
registry.StartModelsUpdater(context.Background()) registry.StartModelsUpdater(context.Background())
} }
@@ -671,6 +678,7 @@ func main() {
} else { } else {
// Start the main proxy service // Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath) managementasset.StartAutoUpdater(context.Background(), configFilePath)
misc.StartAntigravityVersionUpdater(context.Background())
if !localModel { if !localModel {
registry.StartModelsUpdater(context.Background()) registry.StartModelsUpdater(context.Background())
} }

View File

@@ -92,10 +92,14 @@ max-retry-credentials: 0
# Maximum wait time in seconds for a cooled-down credential before triggering a retry. # Maximum wait time in seconds for a cooled-down credential before triggering a retry.
max-retry-interval: 30 max-retry-interval: 30
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
disable-cooling: false
# Quota exceeded behavior # Quota exceeded behavior
quota-exceeded: quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
# Routing strategy for selecting credentials when multiple match. # Routing strategy for selecting credentials when multiple match.
routing: routing:
@@ -104,6 +108,10 @@ routing:
# When true, enable authentication for the WebSocket API (/v1/ws). # When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false ws-auth: false
# When true, enable Gemini CLI internal endpoints (/v1internal:*).
# Default is false for safety.
enable-gemini-cli-endpoint: false
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts. # When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
nonstream-keepalive-interval: 0 nonstream-keepalive-interval: 0
@@ -177,6 +185,8 @@ nonstream-keepalive-interval: 0
# - "API" # - "API"
# - "proxy" # - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request # cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
# Default headers for Claude API requests. Update when Claude Code releases new versions. # Default headers for Claude API requests. Update when Claude Code releases new versions.
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks # In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
@@ -313,6 +323,10 @@ nonstream-keepalive-interval: 0
# These aliases rename model IDs for both model listing and request routing. # These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi. # Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode. # NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
# you select the protocol surface, but inference backend selection can still follow the resolved
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
# You can repeat the same name with different aliases to expose multiple client model names. # You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias: # oauth-model-alias:
# antigravity: # antigravity:

1
go.mod
View File

@@ -83,6 +83,7 @@ require (
github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pierrec/xxHash v0.1.5
github.com/pjbgf/sha1cd v0.5.0 // indirect github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect github.com/rs/xid v1.5.0 // indirect

2
go.sum
View File

@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=

View File

@@ -13,6 +13,7 @@ import (
"github.com/fxamacker/cbor/v2" "github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
@@ -700,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" { if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr) proxyCandidates = append(proxyCandidates, proxyStr)
} }
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
} }
if h != nil && h.cfg != nil { if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" { if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
@@ -722,6 +728,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
return clone return clone
} }
// apiKeyConfigEntry abstracts a single configured API-key entry (e.g. a
// Gemini, Claude, or Codex key from the config file) so resolveAPIKeyConfig
// can match an auth record against any of the typed key lists.
type apiKeyConfigEntry interface {
	GetAPIKey() string
	GetBaseURL() string
}
// resolveAPIKeyConfig locates the config entry that corresponds to an auth
// record by matching the auth's "api_key"/"base_url" attributes against each
// entry's configured key and base URL (whitespace-trimmed, case-insensitive).
// Matching priority within the first pass: exact key+base pair when both
// attributes are present; key match with an empty configured base when only
// the key attribute is present; base match alone when only the base attribute
// is present. If that pass finds nothing and a key attribute exists, a second
// pass accepts any entry whose key alone matches. Returns nil when auth is
// nil, entries is empty, or nothing matches.
func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
	if auth == nil || len(entries) == 0 {
		return nil
	}
	var wantKey, wantBase string
	if attrs := auth.Attributes; attrs != nil {
		wantKey = strings.TrimSpace(attrs["api_key"])
		wantBase = strings.TrimSpace(attrs["base_url"])
	}
	for idx := range entries {
		candidate := &entries[idx]
		haveKey := strings.TrimSpace((*candidate).GetAPIKey())
		haveBase := strings.TrimSpace((*candidate).GetBaseURL())
		switch {
		case wantKey != "" && wantBase != "":
			// Both attributes present: require an exact pair match.
			if strings.EqualFold(haveKey, wantKey) && strings.EqualFold(haveBase, wantBase) {
				return candidate
			}
		case wantKey != "":
			// Key only: accept entries whose configured base is empty
			// (or, vacuously, equal to the empty wanted base).
			if strings.EqualFold(haveKey, wantKey) && (haveBase == "" || strings.EqualFold(haveBase, wantBase)) {
				return candidate
			}
		case wantBase != "":
			// Base only: match on the base URL alone.
			if strings.EqualFold(haveBase, wantBase) {
				return candidate
			}
		}
	}
	if wantKey == "" {
		return nil
	}
	// Fallback pass: relax the base-URL requirement and match on key alone.
	for idx := range entries {
		candidate := &entries[idx]
		if strings.EqualFold(strings.TrimSpace((*candidate).GetAPIKey()), wantKey) {
			return candidate
		}
	}
	return nil
}
// proxyURLFromAPIKeyConfig returns the per-credential proxy URL configured
// for an API-key backed auth record, or "" when the record is not API-key
// based or no matching config entry declares one. OpenAI-compatibility
// providers are delegated to resolveOpenAICompatAPIKeyProxyURL; the gemini,
// claude, and codex providers are matched against their typed key lists.
func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
	if cfg == nil || auth == nil {
		return ""
	}
	kind, account := auth.AccountInfo()
	if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
		// Only API-key credentials carry per-key proxy configuration.
		return ""
	}
	var compatName, providerKey string
	if attrs := auth.Attributes; len(attrs) > 0 {
		compatName = strings.TrimSpace(attrs["compat_name"])
		providerKey = strings.TrimSpace(attrs["provider_key"])
	}
	provider := strings.TrimSpace(auth.Provider)
	if compatName != "" || strings.EqualFold(provider, "openai-compatibility") {
		return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(account), providerKey, compatName)
	}
	switch strings.ToLower(provider) {
	case "gemini":
		if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
			return strings.TrimSpace(entry.ProxyURL)
		}
	case "claude":
		if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
			return strings.TrimSpace(entry.ProxyURL)
		}
	case "codex":
		if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
			return strings.TrimSpace(entry.ProxyURL)
		}
	}
	return ""
}
// resolveOpenAICompatAPIKeyProxyURL resolves the proxy URL for an
// openai-compatibility API key. Provider-name candidates are tried in
// priority order (compat_name attribute, provider_key attribute, then the
// auth's Provider field) against each configured compatibility block's Name.
// The first block whose name matches is authoritative: its key entries are
// scanned for apiKey and the search stops there, returning "" when that
// block has no matching key entry.
func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
	if cfg == nil || auth == nil {
		return ""
	}
	apiKey = strings.TrimSpace(apiKey)
	if apiKey == "" {
		return ""
	}
	// Collect non-empty provider-name candidates in priority order.
	candidates := make([]string, 0, 3)
	for _, raw := range []string{compatName, providerKey, auth.Provider} {
		if name := strings.TrimSpace(raw); name != "" {
			candidates = append(candidates, name)
		}
	}
	for i := range cfg.OpenAICompatibility {
		compat := &cfg.OpenAICompatibility[i]
		matched := false
		for _, candidate := range candidates {
			if strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
				matched = true
				break
			}
		}
		if !matched {
			continue
		}
		for j := range compat.APIKeyEntries {
			entry := &compat.APIKeyEntries[j]
			if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
				return strings.TrimSpace(entry.ProxyURL)
			}
		}
		// The named block matched but holds no such key; do not fall
		// through to other blocks.
		return ""
	}
	return ""
}
func buildProxyTransport(proxyStr string) *http.Transport { func buildProxyTransport(proxyStr string) *http.Transport {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil { if errBuild != nil {

View File

@@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
} }
} }
// TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL verifies that when
// an API-key auth record carries no proxy of its own, apiCallTransport falls
// back to the per-key proxy-url declared in the config (gemini-api-key,
// claude-api-key, codex-api-key, and openai-compatibility entries) rather
// than the global proxy.
func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
	t.Parallel()
	// Config with a global proxy plus a distinct per-key proxy for each
	// provider family, so the chosen proxy identifies the code path taken.
	h := &Handler{
		cfg: &config.Config{
			SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
			GeminiKey: []config.GeminiKey{{
				APIKey:   "gemini-key",
				ProxyURL: "http://gemini-proxy.example.com:8080",
			}},
			ClaudeKey: []config.ClaudeKey{{
				APIKey:   "claude-key",
				ProxyURL: "http://claude-proxy.example.com:8080",
			}},
			CodexKey: []config.CodexKey{{
				APIKey:   "codex-key",
				ProxyURL: "http://codex-proxy.example.com:8080",
			}},
			OpenAICompatibility: []config.OpenAICompatibility{{
				Name:    "bohe",
				BaseURL: "https://bohe.example.com",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
					APIKey:   "compat-key",
					ProxyURL: "http://compat-proxy.example.com:8080",
				}},
			}},
		},
	}
	cases := []struct {
		name      string
		auth      *coreauth.Auth
		wantProxy string
	}{
		{
			name: "gemini",
			auth: &coreauth.Auth{
				Provider:   "gemini",
				Attributes: map[string]string{"api_key": "gemini-key"},
			},
			wantProxy: "http://gemini-proxy.example.com:8080",
		},
		{
			name: "claude",
			auth: &coreauth.Auth{
				Provider:   "claude",
				Attributes: map[string]string{"api_key": "claude-key"},
			},
			wantProxy: "http://claude-proxy.example.com:8080",
		},
		{
			name: "codex",
			auth: &coreauth.Auth{
				Provider:   "codex",
				Attributes: map[string]string{"api_key": "codex-key"},
			},
			wantProxy: "http://codex-proxy.example.com:8080",
		},
		{
			name: "openai-compatibility",
			auth: &coreauth.Auth{
				Provider: "bohe",
				Attributes: map[string]string{
					"api_key":      "compat-key",
					"compat_name":  "bohe",
					"provider_key": "bohe",
				},
			},
			wantProxy: "http://compat-proxy.example.com:8080",
		},
	}
	for _, tc := range cases {
		tc := tc // capture for parallel subtests (pre-Go 1.22 semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			transport := h.apiCallTransport(tc.auth)
			httpTransport, ok := transport.(*http.Transport)
			if !ok {
				t.Fatalf("transport type = %T, want *http.Transport", transport)
			}
			req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
			if errRequest != nil {
				t.Fatalf("http.NewRequest returned error: %v", errRequest)
			}
			// Ask the transport which proxy it would use for the request
			// and assert it is the per-key proxy, not the global one.
			proxyURL, errProxy := httpTransport.Proxy(req)
			if errProxy != nil {
				t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
			}
			if proxyURL == nil || proxyURL.String() != tc.wantProxy {
				t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
			}
		})
	}
}
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) { func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -1047,6 +1047,7 @@ func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Aut
auth.Runtime = existing.Runtime auth.Runtime = existing.Runtime
} }
} }
coreauth.ApplyCustomHeadersFromMetadata(auth)
return auth, nil return auth, nil
} }
@@ -1129,7 +1130,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled}) c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
} }
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file. // PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) { func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil { if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
@@ -1137,11 +1138,12 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
} }
var req struct { var req struct {
Name string `json:"name"` Name string `json:"name"`
Prefix *string `json:"prefix"` Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"` ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"` Headers map[string]string `json:"headers"`
Note *string `json:"note"` Priority *int `json:"priority"`
Note *string `json:"note"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
@@ -1177,13 +1179,107 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
changed := false changed := false
if req.Prefix != nil { if req.Prefix != nil {
targetAuth.Prefix = *req.Prefix prefix := strings.TrimSpace(*req.Prefix)
targetAuth.Prefix = prefix
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if prefix == "" {
delete(targetAuth.Metadata, "prefix")
} else {
targetAuth.Metadata["prefix"] = prefix
}
changed = true changed = true
} }
if req.ProxyURL != nil { if req.ProxyURL != nil {
targetAuth.ProxyURL = *req.ProxyURL proxyURL := strings.TrimSpace(*req.ProxyURL)
targetAuth.ProxyURL = proxyURL
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if proxyURL == "" {
delete(targetAuth.Metadata, "proxy_url")
} else {
targetAuth.Metadata["proxy_url"] = proxyURL
}
changed = true changed = true
} }
if len(req.Headers) > 0 {
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata)
nextHeaders := make(map[string]string, len(existingHeaders))
for k, v := range existingHeaders {
nextHeaders[k] = v
}
headerChanged := false
for key, value := range req.Headers {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
attrKey := "header:" + name
if val == "" {
if _, ok := nextHeaders[name]; ok {
delete(nextHeaders, name)
headerChanged = true
}
if targetAuth.Attributes != nil {
if _, ok := targetAuth.Attributes[attrKey]; ok {
headerChanged = true
}
}
continue
}
if prev, ok := nextHeaders[name]; !ok || prev != val {
headerChanged = true
}
nextHeaders[name] = val
if targetAuth.Attributes != nil {
if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val {
headerChanged = true
}
} else {
headerChanged = true
}
}
if headerChanged {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if targetAuth.Attributes == nil {
targetAuth.Attributes = make(map[string]string)
}
for key, value := range req.Headers {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
attrKey := "header:" + name
if val == "" {
delete(nextHeaders, name)
delete(targetAuth.Attributes, attrKey)
continue
}
nextHeaders[name] = val
targetAuth.Attributes[attrKey] = val
}
if len(nextHeaders) == 0 {
delete(targetAuth.Metadata, "headers")
} else {
metaHeaders := make(map[string]any, len(nextHeaders))
for k, v := range nextHeaders {
metaHeaders[k] = v
}
targetAuth.Metadata["headers"] = metaHeaders
}
changed = true
}
}
if req.Priority != nil || req.Note != nil { if req.Priority != nil || req.Note != nil {
if targetAuth.Metadata == nil { if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any) targetAuth.Metadata = make(map[string]any)
@@ -2138,9 +2234,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct) metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth" metadata["auth_kind"] = "oauth"
metadata["oauth_client_id"] = clientID metadata["oauth_client_id"] = clientID
if clientSecret != "" {
metadata["oauth_client_secret"] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username) metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" { if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email metadata["email"] = email

View File

@@ -0,0 +1,164 @@
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues verifies that the
// "headers" field of PatchAuthFileFields merges into the record's existing
// header state: non-blank values overwrite or add entries in both
// Metadata["headers"] and the "header:<name>" attributes, while blank values
// (after trimming) delete them. It also checks that prefix and proxy_url are
// trimmed and mirrored into metadata.
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)
	store := &memoryAuthStore{}
	manager := coreauth.NewManager(store, nil, nil)
	// Seed an auth record that already carries two custom headers.
	record := &coreauth.Auth{
		ID:       "test.json",
		FileName: "test.json",
		Provider: "claude",
		Attributes: map[string]string{
			"path":            "/tmp/test.json",
			"header:X-Old":    "old",
			"header:X-Remove": "gone",
		},
		Metadata: map[string]any{
			"type": "claude",
			"headers": map[string]any{
				"X-Old":    "old",
				"X-Remove": "gone",
			},
		},
	}
	if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
		t.Fatalf("failed to register auth record: %v", errRegister)
	}
	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
	// X-Old is updated, X-New is added; X-Remove (blank after trim) and
	// X-Nope (empty) must be deleted / never created.
	body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	ctx.Request = req
	h.PatchAuthFileFields(ctx)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
	}
	updated, ok := manager.GetByID("test.json")
	if !ok || updated == nil {
		t.Fatalf("expected auth record to exist after patch")
	}
	if updated.Prefix != "p1" {
		t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
	}
	if updated.ProxyURL != "http://proxy.local" {
		t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
	}
	if updated.Metadata == nil {
		t.Fatalf("expected metadata to be non-nil")
	}
	// prefix/proxy_url must be mirrored into metadata for persistence.
	if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
		t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
	}
	if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
		t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
	}
	headersMeta, ok := updated.Metadata["headers"].(map[string]any)
	if !ok {
		raw, _ := json.Marshal(updated.Metadata["headers"])
		t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
	}
	if got := headersMeta["X-Old"]; got != "new" {
		t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
	}
	if got := headersMeta["X-New"]; got != "v" {
		t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
	}
	if _, ok := headersMeta["X-Remove"]; ok {
		t.Fatalf("expected metadata.headers.X-Remove to be deleted")
	}
	if _, ok := headersMeta["X-Nope"]; ok {
		t.Fatalf("expected metadata.headers.X-Nope to be absent")
	}
	// Attribute mirror ("header:<name>") must stay in sync with metadata.
	if got := updated.Attributes["header:X-Old"]; got != "new" {
		t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
	}
	if got := updated.Attributes["header:X-New"]; got != "v" {
		t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
	}
	if _, ok := updated.Attributes["header:X-Remove"]; ok {
		t.Fatalf("expected attrs header:X-Remove to be deleted")
	}
	if _, ok := updated.Attributes["header:X-Nope"]; ok {
		t.Fatalf("expected attrs header:X-Nope to be absent")
	}
}
// TestPatchAuthFileFields_HeadersEmptyMapIsNoop verifies that sending an
// empty "headers" object to PatchAuthFileFields leaves the record's existing
// header metadata and "header:<name>" attributes untouched (only the other
// patched fields, here "note", take effect).
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)
	store := &memoryAuthStore{}
	manager := coreauth.NewManager(store, nil, nil)
	// Seed a record with one existing custom header that must survive.
	record := &coreauth.Auth{
		ID:       "noop.json",
		FileName: "noop.json",
		Provider: "claude",
		Attributes: map[string]string{
			"path":         "/tmp/noop.json",
			"header:X-Kee": "1",
		},
		Metadata: map[string]any{
			"type": "claude",
			"headers": map[string]any{
				"X-Kee": "1",
			},
		},
	}
	if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
		t.Fatalf("failed to register auth record: %v", errRegister)
	}
	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
	// "headers":{} must be treated as a no-op, not as "clear all headers".
	body := `{"name":"noop.json","note":"hello","headers":{}}`
	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	ctx.Request = req
	h.PatchAuthFileFields(ctx)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
	}
	updated, ok := manager.GetByID("noop.json")
	if !ok || updated == nil {
		t.Fatalf("expected auth record to exist after patch")
	}
	if got := updated.Attributes["header:X-Kee"]; got != "1" {
		t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
	}
	headersMeta, ok := updated.Metadata["headers"].(map[string]any)
	if !ok {
		t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
	}
	if got := headersMeta["X-Kee"]; got != "1" {
		t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
	}
}

View File

@@ -214,19 +214,46 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) {
func (h *Handler) DeleteGeminiKey(c *gin.Context) { func (h *Handler) DeleteGeminiKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" { if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) if baseRaw, okBase := c.GetQuery("base-url"); okBase {
for _, v := range h.cfg.GeminiKey { base := strings.TrimSpace(baseRaw)
if v.APIKey != val { out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
for _, v := range h.cfg.GeminiKey {
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
continue
}
out = append(out, v) out = append(out, v)
} }
if len(out) != len(h.cfg.GeminiKey) {
h.cfg.GeminiKey = out
h.cfg.SanitizeGeminiKeys()
h.persist(c)
} else {
c.JSON(404, gin.H{"error": "item not found"})
}
return
} }
if len(out) != len(h.cfg.GeminiKey) {
h.cfg.GeminiKey = out matchIndex := -1
h.cfg.SanitizeGeminiKeys() matchCount := 0
h.persist(c) for i := range h.cfg.GeminiKey {
} else { if strings.TrimSpace(h.cfg.GeminiKey[i].APIKey) == val {
matchCount++
if matchIndex == -1 {
matchIndex = i
}
}
}
if matchCount == 0 {
c.JSON(404, gin.H{"error": "item not found"}) c.JSON(404, gin.H{"error": "item not found"})
return
} }
if matchCount > 1 {
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
return
}
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:matchIndex], h.cfg.GeminiKey[matchIndex+1:]...)
h.cfg.SanitizeGeminiKeys()
h.persist(c)
return return
} }
if idxStr := c.Query("index"); idxStr != "" { if idxStr := c.Query("index"); idxStr != "" {
@@ -335,14 +362,39 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) {
} }
func (h *Handler) DeleteClaudeKey(c *gin.Context) { func (h *Handler) DeleteClaudeKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" { if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) if baseRaw, okBase := c.GetQuery("base-url"); okBase {
for _, v := range h.cfg.ClaudeKey { base := strings.TrimSpace(baseRaw)
if v.APIKey != val { out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
for _, v := range h.cfg.ClaudeKey {
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
continue
}
out = append(out, v) out = append(out, v)
} }
h.cfg.ClaudeKey = out
h.cfg.SanitizeClaudeKeys()
h.persist(c)
return
}
matchIndex := -1
matchCount := 0
for i := range h.cfg.ClaudeKey {
if strings.TrimSpace(h.cfg.ClaudeKey[i].APIKey) == val {
matchCount++
if matchIndex == -1 {
matchIndex = i
}
}
}
if matchCount > 1 {
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
return
}
if matchIndex != -1 {
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:matchIndex], h.cfg.ClaudeKey[matchIndex+1:]...)
} }
h.cfg.ClaudeKey = out
h.cfg.SanitizeClaudeKeys() h.cfg.SanitizeClaudeKeys()
h.persist(c) h.persist(c)
return return
@@ -601,13 +653,38 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" { if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey)) if baseRaw, okBase := c.GetQuery("base-url"); okBase {
for _, v := range h.cfg.VertexCompatAPIKey { base := strings.TrimSpace(baseRaw)
if v.APIKey != val { out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
for _, v := range h.cfg.VertexCompatAPIKey {
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
continue
}
out = append(out, v) out = append(out, v)
} }
h.cfg.VertexCompatAPIKey = out
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
matchIndex := -1
matchCount := 0
for i := range h.cfg.VertexCompatAPIKey {
if strings.TrimSpace(h.cfg.VertexCompatAPIKey[i].APIKey) == val {
matchCount++
if matchIndex == -1 {
matchIndex = i
}
}
}
if matchCount > 1 {
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
return
}
if matchIndex != -1 {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:matchIndex], h.cfg.VertexCompatAPIKey[matchIndex+1:]...)
} }
h.cfg.VertexCompatAPIKey = out
h.cfg.SanitizeVertexCompatKeys() h.cfg.SanitizeVertexCompatKeys()
h.persist(c) h.persist(c)
return return
@@ -919,14 +996,39 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
} }
func (h *Handler) DeleteCodexKey(c *gin.Context) { func (h *Handler) DeleteCodexKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" { if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) if baseRaw, okBase := c.GetQuery("base-url"); okBase {
for _, v := range h.cfg.CodexKey { base := strings.TrimSpace(baseRaw)
if v.APIKey != val { out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
for _, v := range h.cfg.CodexKey {
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
continue
}
out = append(out, v) out = append(out, v)
} }
h.cfg.CodexKey = out
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
matchIndex := -1
matchCount := 0
for i := range h.cfg.CodexKey {
if strings.TrimSpace(h.cfg.CodexKey[i].APIKey) == val {
matchCount++
if matchIndex == -1 {
matchIndex = i
}
}
}
if matchCount > 1 {
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
return
}
if matchIndex != -1 {
h.cfg.CodexKey = append(h.cfg.CodexKey[:matchIndex], h.cfg.CodexKey[matchIndex+1:]...)
} }
h.cfg.CodexKey = out
h.cfg.SanitizeCodexKeys() h.cfg.SanitizeCodexKeys()
h.persist(c) h.persist(c)
return return

View File

@@ -0,0 +1,172 @@
package management
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func writeTestConfigFile(t *testing.T) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if errWrite := os.WriteFile(path, []byte("{}\n"), 0o600); errWrite != nil {
t.Fatalf("failed to write test config: %v", errWrite)
}
return path
}
func TestDeleteGeminiKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
h := &Handler{
cfg: &config.Config{
GeminiKey: []config.GeminiKey{
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
},
},
configFilePath: writeTestConfigFile(t),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key", nil)
h.DeleteGeminiKey(c)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
if got := len(h.cfg.GeminiKey); got != 2 {
t.Fatalf("gemini keys len = %d, want 2", got)
}
}
func TestDeleteGeminiKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
h := &Handler{
cfg: &config.Config{
GeminiKey: []config.GeminiKey{
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
},
},
configFilePath: writeTestConfigFile(t),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key&base-url=https://a.example.com", nil)
h.DeleteGeminiKey(c)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if got := len(h.cfg.GeminiKey); got != 1 {
t.Fatalf("gemini keys len = %d, want 1", got)
}
if got := h.cfg.GeminiKey[0].BaseURL; got != "https://b.example.com" {
t.Fatalf("remaining base-url = %q, want %q", got, "https://b.example.com")
}
}
func TestDeleteClaudeKey_DeletesEmptyBaseURLWhenExplicitlyProvided(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
h := &Handler{
cfg: &config.Config{
ClaudeKey: []config.ClaudeKey{
{APIKey: "shared-key", BaseURL: ""},
{APIKey: "shared-key", BaseURL: "https://claude.example.com"},
},
},
configFilePath: writeTestConfigFile(t),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/claude-api-key?api-key=shared-key&base-url=", nil)
h.DeleteClaudeKey(c)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if got := len(h.cfg.ClaudeKey); got != 1 {
t.Fatalf("claude keys len = %d, want 1", got)
}
if got := h.cfg.ClaudeKey[0].BaseURL; got != "https://claude.example.com" {
t.Fatalf("remaining base-url = %q, want %q", got, "https://claude.example.com")
}
}
func TestDeleteVertexCompatKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
h := &Handler{
cfg: &config.Config{
VertexCompatAPIKey: []config.VertexCompatKey{
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
},
},
configFilePath: writeTestConfigFile(t),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/vertex-api-key?api-key=shared-key&base-url=https://b.example.com", nil)
h.DeleteVertexCompatKey(c)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if got := len(h.cfg.VertexCompatAPIKey); got != 1 {
t.Fatalf("vertex keys len = %d, want 1", got)
}
if got := h.cfg.VertexCompatAPIKey[0].BaseURL; got != "https://a.example.com" {
t.Fatalf("remaining base-url = %q, want %q", got, "https://a.example.com")
}
}
func TestDeleteCodexKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
h := &Handler{
cfg: &config.Config{
CodexKey: []config.CodexKey{
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
},
},
configFilePath: writeTestConfigFile(t),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/codex-api-key?api-key=shared-key", nil)
h.DeleteCodexKey(c)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
if got := len(h.cfg.CodexKey); got != 2 {
t.Fatalf("codex keys len = %d, want 2", got)
}
}

View File

@@ -15,6 +15,8 @@ import (
) )
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE" const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes. // RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct { type RequestInfo struct {
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if len(apiResponse) > 0 { if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse) _ = w.streamWriter.WriteAPIResponse(apiResponse)
} }
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
if len(apiWebsocketTimeline) > 0 {
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
}
if err := w.streamWriter.Close(); err != nil { if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil w.streamWriter = nil
return err return err
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil return nil
} }
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
} }
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string { func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
return data return data
} }
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
if !isExist {
return nil
}
data, ok := apiTimeline.([]byte)
if !ok || len(data) == 0 {
return nil
}
return bytes.Clone(data)
}
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time { func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP") ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist { if !isExist {
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
} }
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte { func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil { if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist { return body
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
} }
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 { if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body return w.requestInfo.Body
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
return nil return nil
} }
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.body == nil || w.body.Len() == 0 {
return nil
}
return bytes.Clone(w.body.Bytes())
}
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
}
func extractBodyOverride(c *gin.Context, key string) []byte {
if c == nil {
return nil
}
bodyOverride, isExist := c.Get(key)
if !isExist {
return nil
}
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil { if w.requestInfo == nil {
return nil return nil
} }
if loggerWithOptions, ok := w.logger.(interface { if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok { }); ok {
return loggerWithOptions.LogRequestWithOptions( return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL, w.requestInfo.URL,
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode, statusCode,
headers, headers,
body, body,
websocketTimeline,
apiRequestBody, apiRequestBody,
apiResponseBody, apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors, apiResponseErrors,
forceLog, forceLog,
w.requestInfo.RequestID, w.requestInfo.RequestID,
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode, statusCode,
headers, headers,
body, body,
websocketTimeline,
apiRequestBody, apiRequestBody,
apiResponseBody, apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors, apiResponseErrors,
w.requestInfo.RequestID, w.requestInfo.RequestID,
w.requestInfo.Timestamp, w.requestInfo.Timestamp,

View File

@@ -1,10 +1,14 @@
package middleware package middleware
import ( import (
"bytes"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
) )
func TestExtractRequestBodyPrefersOverride(t *testing.T) { func TestExtractRequestBodyPrefersOverride(t *testing.T) {
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder) c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{} wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
c.Set(requestBodyOverrideContextKey, "override-as-string") c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c) body := wrapper.extractRequestBody(c)
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string") t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
} }
} }
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
wrapper.body.WriteString("original-response")
body := wrapper.extractResponseBody(c)
if string(body) != "original-response" {
t.Fatalf("response body = %q, want %q", string(body), "original-response")
}
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
body = wrapper.extractResponseBody(c)
if string(body) != "override-response" {
t.Fatalf("response body = %q, want %q", string(body), "override-response")
}
body[0] = 'X'
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
t.Fatalf("response override should be cloned, got %q", string(got))
}
}
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
body := wrapper.extractResponseBody(c)
if string(body) != "override-response-as-string" {
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
}
}
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
override := []byte("body-override")
c.Set(requestBodyOverrideContextKey, override)
body := extractBodyOverride(c, requestBodyOverrideContextKey)
if !bytes.Equal(body, override) {
t.Fatalf("body override = %q, want %q", string(body), string(override))
}
body[0] = 'X'
if !bytes.Equal(override, []byte("body-override")) {
t.Fatalf("override mutated: %q", string(override))
}
}
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
if got := wrapper.extractWebsocketTimeline(c); got != nil {
t.Fatalf("expected nil websocket timeline, got %q", string(got))
}
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
body := wrapper.extractWebsocketTimeline(c)
if string(body) != "timeline" {
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
}
}
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
streamWriter := &testStreamingLogWriter{}
wrapper := &ResponseWriterWrapper{
ResponseWriter: c.Writer,
logger: &testRequestLogger{enabled: true},
requestInfo: &RequestInfo{
URL: "/v1/responses",
Method: "POST",
Headers: map[string][]string{"Content-Type": {"application/json"}},
RequestID: "req-1",
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
},
isStreaming: true,
streamWriter: streamWriter,
}
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
if err := wrapper.Finalize(c); err != nil {
t.Fatalf("Finalize error: %v", err)
}
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
}
if !streamWriter.closed {
t.Fatal("expected stream writer to be closed")
}
}
type testRequestLogger struct {
enabled bool
}
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
return nil
}
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
return &testStreamingLogWriter{}, nil
}
func (l *testRequestLogger) IsEnabled() bool {
return l.enabled
}
type testStreamingLogWriter struct {
apiWebsocketTimeline []byte
closed bool
}
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
func (w *testStreamingLogWriter) Close() error {
w.closed = true
return nil
}

View File

@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
return return
} }
// Sanitize request body: remove thinking blocks with invalid signatures
// to prevent upstream API 400 errors
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
// Restore the body for the handler to read // Restore the body for the handler to read
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -249,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
rewriter := NewResponseRewriter(c.Writer, modelName) rewriter := NewResponseRewriter(c.Writer, modelName)
rewriter.suppressThinking = true
c.Writer = rewriter c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths // Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c) filterAntropicBetaHeader(c)
@@ -259,10 +264,17 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
} else if len(providers) > 0 { } else if len(providers) > 0 {
// Log: Using local provider (free) // Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
// Wrap with ResponseRewriter for local providers too, because upstream
// proxies (e.g. NewAPI) may return a different model name and lack
// Amp-required fields like thinking.signature.
rewriter := NewResponseRewriter(c.Writer, modelName)
rewriter.suppressThinking = providerName != "claude"
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths // Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c) filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c) handler(c)
rewriter.Flush()
} else { } else {
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

View File

@@ -129,11 +129,11 @@ func TestModifyResponse_GzipScenarios(t *testing.T) {
wantCE: "", wantCE: "",
}, },
{ {
name: "skips_non_2xx_status", name: "decompresses_non_2xx_status_when_gzip_detected",
header: http.Header{}, header: http.Header{},
body: good, body: good,
status: 404, status: 404,
wantBody: good, wantBody: goodJSON,
wantCE: "", wantCE: "",
}, },
} }

View File

@@ -2,6 +2,8 @@ package amp
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt"
"net/http" "net/http"
"strings" "strings"
@@ -12,15 +14,17 @@ import (
) )
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body // ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
// It's used to rewrite model names in responses when model mapping is used // It is used to rewrite model names in responses when model mapping is used
// and to keep Amp-compatible response shapes.
type ResponseRewriter struct { type ResponseRewriter struct {
gin.ResponseWriter gin.ResponseWriter
body *bytes.Buffer body *bytes.Buffer
originalModel string originalModel string
isStreaming bool isStreaming bool
suppressThinking bool
} }
// NewResponseRewriter creates a new response rewriter for model name substitution // NewResponseRewriter creates a new response rewriter for model name substitution.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter { func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
return &ResponseRewriter{ return &ResponseRewriter{
ResponseWriter: w, ResponseWriter: w,
@@ -33,15 +37,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
func looksLikeSSEChunk(data []byte) bool { func looksLikeSSEChunk(data []byte) bool {
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered. // Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
// Heuristics are intentionally simple and cheap. // We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
return bytes.Contains(data, []byte("data:")) || for _, line := range bytes.Split(data, []byte("\n")) {
bytes.Contains(data, []byte("event:")) || trimmed := bytes.TrimSpace(line)
bytes.Contains(data, []byte("message_start")) || if bytes.HasPrefix(trimmed, []byte("data:")) ||
bytes.Contains(data, []byte("message_delta")) || bytes.HasPrefix(trimmed, []byte("event:")) {
bytes.Contains(data, []byte("content_block_start")) || return true
bytes.Contains(data, []byte("content_block_delta")) || }
bytes.Contains(data, []byte("content_block_stop")) || }
bytes.Contains(data, []byte("\n\n")) return false
} }
func (rw *ResponseRewriter) enableStreaming(reason string) error { func (rw *ResponseRewriter) enableStreaming(reason string) error {
@@ -95,7 +99,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
} }
if rw.isStreaming { if rw.isStreaming {
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) rewritten := rw.rewriteStreamChunk(data)
n, err := rw.ResponseWriter.Write(rewritten)
if err == nil { if err == nil {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
@@ -106,7 +111,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
return rw.body.Write(data) return rw.body.Write(data)
} }
// Flush writes the buffered response with model names rewritten
func (rw *ResponseRewriter) Flush() { func (rw *ResponseRewriter) Flush() {
if rw.isStreaming { if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
@@ -115,40 +119,79 @@ func (rw *ResponseRewriter) Flush() {
return return
} }
if rw.body.Len() > 0 { if rw.body.Len() > 0 {
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil { rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
// Update Content-Length to match the rewritten body size, since
// signature injection and model name changes alter the payload length.
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err) log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
} }
} }
} }
// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"} var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON // ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility // in API responses so that the Amp TUI does not crash on P.signature.length.
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { func ensureAmpSignature(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected for index, block := range gjson.GetBytes(data, "content").Array() {
// The Amp client struggles when both thinking and tool_use blocks are present blockType := block.Get("type").String()
if blockType != "tool_use" && blockType != "thinking" {
continue
}
signaturePath := fmt.Sprintf("content.%d.signature", index)
if gjson.GetBytes(data, signaturePath).Exists() {
continue
}
var err error
data, err = sjson.SetBytes(data, signaturePath, "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
break
}
}
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
var err error
data, err = sjson.SetBytes(data, "content_block.signature", "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
}
}
return data
}
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
if !rw.suppressThinking {
return data
}
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() { if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`) filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() { if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int() originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int() filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount { if originalCount > filteredCount {
var err error var err error
data, err = sjson.SetBytes(data, "content", filtered.Value()) data, err = sjson.SetBytes(data, "content", filtered.Value())
if err != nil { if err != nil {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err) log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
} }
} }
} }
} }
return data
}
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
}
if rw.originalModel == "" { if rw.originalModel == "" {
return data return data
} }
@@ -160,24 +203,164 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
return data return data
} }
// rewriteStreamChunk rewrites model names in SSE stream chunks
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
if rw.originalModel == "" { lines := bytes.Split(chunk, []byte("\n"))
return chunk var out [][]byte
i := 0
for i < len(lines) {
line := lines[i]
trimmed := bytes.TrimSpace(line)
// Case 1: "event:" line - look ahead for its "data:" line
if bytes.HasPrefix(trimmed, []byte("event: ")) {
// Scan forward past blank lines to find the data: line
dataIdx := -1
for j := i + 1; j < len(lines); j++ {
t := bytes.TrimSpace(lines[j])
if len(t) == 0 {
continue
}
if bytes.HasPrefix(t, []byte("data: ")) {
dataIdx = j
}
break
}
if dataIdx >= 0 {
// Found event+data pair - process through rewriter
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
if rewritten == nil {
i = dataIdx + 1
continue
}
// Emit event line
out = append(out, line)
// Emit blank lines between event and data
for k := i + 1; k < dataIdx; k++ {
out = append(out, lines[k])
}
// Emit rewritten data
out = append(out, append([]byte("data: "), rewritten...))
i = dataIdx + 1
continue
}
}
// No data line found (orphan event from cross-chunk split)
// Pass it through as-is - the data will arrive in the next chunk
out = append(out, line)
i++
continue
}
// Case 2: standalone "data:" line (no preceding event: in this chunk)
if bytes.HasPrefix(trimmed, []byte("data: ")) {
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
if rewritten != nil {
out = append(out, append([]byte("data: "), rewritten...))
}
i++
continue
}
}
// Case 3: everything else
out = append(out, line)
i++
} }
// SSE format: "data: {json}\n\n" return bytes.Join(out, []byte("\n"))
lines := bytes.Split(chunk, []byte("\n")) }
for i, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) { // rewriteStreamEvent processes a single JSON event in the SSE stream.
jsonData := bytes.TrimPrefix(line, []byte("data: ")) // It rewrites model names and ensures signature fields exist.
if len(jsonData) > 0 && jsonData[0] == '{' { // NOTE: streaming mode does NOT suppress thinking blocks - they are
// Rewrite JSON in the data line // passed through with signature injection to avoid breaking SSE index
rewritten := rw.rewriteModelInResponse(jsonData) // alignment and TUI rendering.
lines[i] = append([]byte("data: "), rewritten...) func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.originalModel)
} }
} }
} }
return bytes.Join(lines, []byte("\n")) return data
}
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
// array before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking
// blocks and does not accept a signature field on tool_use blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
modified := false
for msgIdx, msg := range messages.Array() {
if msg.Get("role").String() != "assistant" {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
var keepBlocks []interface{}
contentModified := false
for _, block := range content.Array() {
blockType := block.Get("type").String()
if blockType == "thinking" {
sig := block.Get("signature")
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
contentModified = true
continue
}
}
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
blockRaw := []byte(block.Raw)
if blockType == "tool_use" && block.Get("signature").Exists() {
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
contentModified = true
}
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
}
if contentModified {
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
var err error
if len(keepBlocks) == 0 {
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
} else {
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
}
if err != nil {
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
continue
}
modified = true
}
}
if modified {
log.Debugf("Amp RequestSanitizer: sanitized request body")
}
return body
} }

View File

@@ -1,6 +1,7 @@
package amp package amp
import ( import (
"strings"
"testing" "testing"
) )
@@ -100,6 +101,80 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
} }
} }
// TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection verifies that
// streaming rewrites keep every thinking frame intact (no suppression) and that
// an empty signature is injected into both the thinking and tool_use blocks.
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
	rewriter := &ResponseRewriter{}
	sse := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")

	result := rewriter.rewriteStreamChunk(sse)
	out := string(result)

	// Streaming mode must NOT suppress thinking blocks; dropping them would
	// break SSE index alignment and TUI rendering downstream.
	if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
		t.Fatalf("expected thinking content_block_start to be preserved, got %s", out)
	}
	if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
		t.Fatalf("expected thinking_delta to be preserved, got %s", out)
	}
	if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
		t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", out)
	}
	if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
		t.Fatalf("expected tool_use content_block frame to remain, got %s", out)
	}

	// One signature injection per block: thinking and tool_use.
	if count := strings.Count(out, `"signature":""`); count != 2 {
		t.Fatalf("expected 2 signature injections, but got %d in %s", count, out)
	}
}
// TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures checks that
// thinking blocks with whitespace-only or non-string signatures are dropped while
// validly-signed thinking blocks and non-thinking content survive.
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
	payload := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)

	got := SanitizeAmpRequestBody(payload)
	out := string(got)

	if contains(got, []byte("drop-whitespace")) {
		t.Fatalf("expected whitespace-only signature block to be removed, got %s", out)
	}
	if contains(got, []byte("drop-number")) {
		t.Fatalf("expected non-string signature block to be removed, got %s", out)
	}
	if !contains(got, []byte("keep-valid")) {
		t.Fatalf("expected valid thinking block to remain, got %s", out)
	}
	if !contains(got, []byte("keep-text")) {
		t.Fatalf("expected non-thinking content to remain, got %s", out)
	}
}
// TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks checks that the
// proxy-injected signature field is removed from tool_use blocks while the
// tool_use block itself and a validly-signed thinking block are retained.
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
	payload := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)

	got := SanitizeAmpRequestBody(payload)
	out := string(got)

	if contains(got, []byte(`"signature":""`)) {
		t.Fatalf("expected signature to be stripped from tool_use block, got %s", out)
	}
	if !contains(got, []byte(`"valid-sig"`)) {
		t.Fatalf("expected thinking signature to remain, got %s", out)
	}
	if !contains(got, []byte(`"tool_use"`)) {
		t.Fatalf("expected tool_use block to remain, got %s", out)
	}
}
// TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature checks the
// combined case: an invalid thinking block is removed entirely AND the signature
// field is stripped from the sibling tool_use block, which itself remains.
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
	payload := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)

	got := SanitizeAmpRequestBody(payload)
	out := string(got)

	if contains(got, []byte("drop-me")) {
		t.Fatalf("expected invalid thinking block to be removed, got %s", out)
	}
	if contains(got, []byte(`"signature"`)) {
		t.Fatalf("expected signature to be stripped from tool_use block, got %s", out)
	}
	if !contains(got, []byte(`"tool_use"`)) {
		t.Fatalf("expected tool_use block to remain, got %s", out)
	}
}
func contains(data, substr []byte) bool { func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ { for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) { if string(data[i:i+len(substr)]) == string(substr) {

View File

@@ -323,6 +323,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// setupRoutes configures the API routes for the server. // setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers. // It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() { func (s *Server) setupRoutes() {
s.engine.GET("/healthz", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
s.engine.GET("/management.html", s.serveManagementControlPanel) s.engine.GET("/management.html", s.serveManagementControlPanel)
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers) openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers) geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
@@ -569,6 +573,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel) mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel) mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
mgmt.GET("/copilot-quota", s.mgmt.GetCopilotQuota)
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys) mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys) mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys) mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server {
return NewServer(cfg, authManager, accessManager, configPath) return NewServer(cfg, authManager, accessManager, configPath)
} }
// TestHealthz exercises the unauthenticated liveness endpoint: it must answer
// 200 with a JSON body whose "status" field is "ok".
func TestHealthz(t *testing.T) {
	srv := newTestServer(t)

	req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
	rec := httptest.NewRecorder()
	srv.engine.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("unexpected status code: got %d want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var payload struct {
		Status string `json:"status"`
	}
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to parse response JSON: %v; body=%s", err, rec.Body.String())
	}
	if payload.Status != "ok" {
		t.Fatalf("unexpected response status: got %q want %q", payload.Status, "ok")
	}
}
func TestAmpProviderModelRoutes(t *testing.T) { func TestAmpProviderModelRoutes(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -172,6 +195,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
nil,
true, true,
"issue-1711", "issue-1711",
time.Now(), time.Now(),

View File

@@ -88,7 +88,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
"client_id": {ClientID}, "client_id": {ClientID},
"response_type": {"code"}, "response_type": {"code"},
"redirect_uri": {RedirectURI}, "redirect_uri": {RedirectURI},
"scope": {"org:create_api_key user:profile user:inference"}, "scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
"code_challenge": {pkceCodes.CodeChallenge}, "code_challenge": {pkceCodes.CodeChallenge},
"code_challenge_method": {"S256"}, "code_challenge_method": {"S256"},
"state": {state}, "state": {state},

View File

@@ -235,6 +235,74 @@ type CopilotModelEntry struct {
Capabilities map[string]any `json:"capabilities,omitempty"` Capabilities map[string]any `json:"capabilities,omitempty"`
} }
// CopilotModelLimits holds the token limits returned by the Copilot /models API
// under capabilities.limits. These limits vary by account type (individual vs
// business) and are the authoritative source for enforcing prompt size.
// A zero value for any field means the API did not report that limit.
type CopilotModelLimits struct {
	// MaxContextWindowTokens is the total context window (prompt + output).
	MaxContextWindowTokens int
	// MaxPromptTokens is the hard limit on input/prompt tokens.
	// Exceeding this triggers a 400 error from the Copilot API.
	MaxPromptTokens int
	// MaxOutputTokens is the maximum number of output/completion tokens.
	MaxOutputTokens int
}
// Limits extracts the token limits from the model's capabilities map.
// Returns nil if no limits are available or the structure is unexpected.
//
// Expected Copilot API shape:
//
//	"capabilities": {
//	  "limits": {
//	    "max_context_window_tokens": 200000,
//	    "max_prompt_tokens": 168000,
//	    "max_output_tokens": 32000
//	  }
//	}
func (e *CopilotModelEntry) Limits() *CopilotModelLimits {
	caps := e.Capabilities
	if caps == nil {
		return nil
	}
	// A missing "limits" key and a non-object value both fail this assertion.
	limits, ok := caps["limits"].(map[string]any)
	if !ok {
		return nil
	}
	parsed := CopilotModelLimits{
		MaxContextWindowTokens: anyToInt(limits["max_context_window_tokens"]),
		MaxPromptTokens:        anyToInt(limits["max_prompt_tokens"]),
		MaxOutputTokens:        anyToInt(limits["max_output_tokens"]),
	}
	// An all-zero result means no usable limits were reported.
	if parsed == (CopilotModelLimits{}) {
		return nil
	}
	return &parsed
}
// anyToInt converts a JSON-decoded numeric value to int.
// Go's encoding/json decodes numbers into float64 when the target is
// any/interface{}, so float64 is the common case; the other numeric kinds are
// accepted for robustness. Any other value (including nil) yields 0.
func anyToInt(v any) int {
	if i, ok := v.(int); ok {
		return i
	}
	if i64, ok := v.(int64); ok {
		return int(i64)
	}
	if f64, ok := v.(float64); ok {
		return int(f64)
	}
	if f32, ok := v.(float32); ok {
		return int(f32)
	}
	return 0
}
// CopilotModelsResponse represents the response from the Copilot /models endpoint. // CopilotModelsResponse represents the response from the Copilot /models endpoint.
type CopilotModelsResponse struct { type CopilotModelsResponse struct {
Data []CopilotModelEntry `json:"data"` Data []CopilotModelEntry `json:"data"`

View File

@@ -30,6 +30,10 @@ type VertexCredentialStorage struct {
// Type is the provider identifier stored alongside credentials. Always "vertex". // Type is the provider identifier stored alongside credentials. Always "vertex".
Type string `json:"type"` Type string `json:"type"`
// Prefix optionally namespaces models for this credential (e.g., "teamA").
// This results in model names like "teamA/gemini-2.0-flash".
Prefix string `json:"prefix,omitempty"`
} }
// SaveTokenToFile writes the credential payload to the given file path in JSON format. // SaveTokenToFile writes the credential payload to the given file path in JSON format.

View File

@@ -20,7 +20,7 @@ import (
// DoVertexImport imports a Google Cloud service account key JSON and persists // DoVertexImport imports a Google Cloud service account key JSON and persists
// it as a "vertex" provider credential. The file content is embedded in the auth // it as a "vertex" provider credential. The file content is embedded in the auth
// file to allow portable deployment across stores. // file to allow portable deployment across stores.
func DoVertexImport(cfg *config.Config, keyPath string) { func DoVertexImport(cfg *config.Config, keyPath string, prefix string) {
if cfg == nil { if cfg == nil {
cfg = &config.Config{} cfg = &config.Config{}
} }
@@ -62,13 +62,28 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
// Default location if not provided by user. Can be edited in the saved file later. // Default location if not provided by user. Can be edited in the saved file later.
location := "us-central1" location := "us-central1"
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID)) // Normalize and validate prefix: must be a single segment (no "/" allowed).
prefix = strings.TrimSpace(prefix)
prefix = strings.Trim(prefix, "/")
if prefix != "" && strings.Contains(prefix, "/") {
log.Errorf("vertex-import: prefix must be a single segment (no '/' allowed): %q", prefix)
return
}
// Include prefix in filename so importing the same project with different
// prefixes creates separate credential files instead of overwriting.
baseName := sanitizeFilePart(projectID)
if prefix != "" {
baseName = sanitizeFilePart(prefix) + "-" + baseName
}
fileName := fmt.Sprintf("vertex-%s.json", baseName)
// Build auth record // Build auth record
storage := &vertex.VertexCredentialStorage{ storage := &vertex.VertexCredentialStorage{
ServiceAccount: sa, ServiceAccount: sa,
ProjectID: projectID, ProjectID: projectID,
Email: email, Email: email,
Location: location, Location: location,
Prefix: prefix,
} }
metadata := map[string]any{ metadata := map[string]any{
"service_account": sa, "service_account": sa,
@@ -76,6 +91,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
"email": email, "email": email,
"location": location, "location": location,
"type": "vertex", "type": "vertex",
"prefix": prefix,
"label": labelForVertex(projectID, email), "label": labelForVertex(projectID, email),
} }
record := &coreauth.Auth{ record := &coreauth.Auth{

View File

@@ -211,6 +211,10 @@ type QuotaExceeded struct {
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded. // SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
} }
// RoutingConfig configures how credentials are selected for requests. // RoutingConfig configures how credentials are selected for requests.
@@ -257,8 +261,8 @@ type AmpCode struct {
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys. // UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used. // When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// If no match is found, falls back to UpstreamAPIKey (default behavior). // is used for the upstream Amp request.
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"` UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.) // RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
@@ -380,6 +384,11 @@ type ClaudeKey struct {
// Cloak configures request cloaking for non-Claude-Code clients. // Cloak configures request cloaking for non-Claude-Code clients.
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"` Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
// Claude /v1/messages requests. It is disabled by default so upstream seed
// changes do not alter the proxy's legacy behavior.
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
} }
func (k ClaudeKey) GetAPIKey() string { return k.APIKey } func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
@@ -972,6 +981,7 @@ func (cfg *Config) SanitizeKiroKeys() {
} }
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials. // SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
// It uses API key + base URL as the uniqueness key.
func (cfg *Config) SanitizeGeminiKeys() { func (cfg *Config) SanitizeGeminiKeys() {
if cfg == nil { if cfg == nil {
return return
@@ -990,10 +1000,11 @@ func (cfg *Config) SanitizeGeminiKeys() {
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = NormalizeHeaders(entry.Headers) entry.Headers = NormalizeHeaders(entry.Headers)
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
if _, exists := seen[entry.APIKey]; exists { uniqueKey := entry.APIKey + "|" + entry.BaseURL
if _, exists := seen[uniqueKey]; exists {
continue continue
} }
seen[entry.APIKey] = struct{}{} seen[uniqueKey] = struct{}{}
out = append(out, entry) out = append(out, entry)
} }
cfg.GeminiKey = out cfg.GeminiKey = out

View File

@@ -9,6 +9,10 @@ type SDKConfig struct {
// ProxyURL is the URL of an optional proxy server to use for outbound requests. // ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"` ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
// Default is false for safety; when false, /v1internal:* requests are rejected.
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") // ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
// to target prefixed credentials. When false, unprefixed model requests may use prefixed // to target prefixed credentials. When false, unprefixed model requests may use prefixed
// credentials as well. // credentials as well.

View File

@@ -4,6 +4,7 @@
package logging package logging
import ( import (
"bufio"
"bytes" "bytes"
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
@@ -41,15 +42,17 @@ type RequestLogger interface {
// - statusCode: The response status code // - statusCode: The response status code
// - responseHeaders: The response headers // - responseHeaders: The response headers
// - response: The raw response data // - response: The raw response data
// - websocketTimeline: Optional downstream websocket event timeline
// - apiRequest: The API request data // - apiRequest: The API request data
// - apiResponse: The API response data // - apiResponse: The API response data
// - apiWebsocketTimeline: Optional upstream websocket event timeline
// - requestID: Optional request ID for log file naming // - requestID: Optional request ID for log file naming
// - requestTimestamp: When the request was received // - requestTimestamp: When the request was received
// - apiResponseTimestamp: When the API response was received // - apiResponseTimestamp: When the API response was received
// //
// Returns: // Returns:
// - error: An error if logging fails, nil otherwise // - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks. // LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
// //
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise // - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error WriteAPIResponse(apiResponse []byte) error
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
// This should be called when upstream communication happened over websocket.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received. // SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
// //
// Parameters: // Parameters:
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
// //
// Returns: // Returns:
// - error: An error if logging fails, nil otherwise // - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp) return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
} }
// LogRequestWithOptions logs a request with optional forced logging behavior. // LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled. // The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp) return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
} }
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error { func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
if !l.enabled && !force { if !l.enabled && !force {
return nil return nil
} }
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
requestHeaders, requestHeaders,
body, body,
requestBodyPath, requestBodyPath,
websocketTimeline,
apiRequest, apiRequest,
apiResponse, apiResponse,
apiWebsocketTimeline,
apiResponseErrors, apiResponseErrors,
statusCode, statusCode,
responseHeaders, responseHeaders,
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
requestHeaders map[string][]string, requestHeaders map[string][]string,
requestBody []byte, requestBody []byte,
requestBodyPath string, requestBodyPath string,
websocketTimeline []byte,
apiRequest []byte, apiRequest []byte,
apiResponse []byte, apiResponse []byte,
apiWebsocketTimeline []byte,
apiResponseErrors []*interfaces.ErrorMessage, apiResponseErrors []*interfaces.ErrorMessage,
statusCode int, statusCode int,
responseHeaders map[string][]string, responseHeaders map[string][]string,
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if requestTimestamp.IsZero() { if requestTimestamp.IsZero() {
requestTimestamp = time.Now() requestTimestamp = time.Now()
} }
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil { isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite return errWrite
} }
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil { if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil { if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
return errWrite return errWrite
} }
if isWebsocketTranscript {
// Intentionally omit the generic downstream HTTP response section for websocket
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
// and appending a one-off upgrade response snapshot would dilute that transcript.
return nil
}
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true) return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
} }
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
body []byte, body []byte,
bodyPath string, bodyPath string,
timestamp time.Time, timestamp time.Time,
downstreamTransport string,
upstreamTransport string,
includeBody bool,
) error { ) error {
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil { if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
return errWrite return errWrite
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil { if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
return errWrite return errWrite
} }
if strings.TrimSpace(downstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
return errWrite
}
}
if strings.TrimSpace(upstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil { if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
return errWrite return errWrite
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite return errWrite
} }
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
} }
} }
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite return errWrite
} }
if !includeBody {
return nil
}
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil { if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
return errWrite return errWrite
} }
bodyTrailingNewlines := 1
if bodyPath != "" { if bodyPath != "" {
bodyFile, errOpen := os.Open(bodyPath) bodyFile, errOpen := os.Open(bodyPath)
if errOpen != nil { if errOpen != nil {
return errOpen return errOpen
} }
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil { tracker := &trailingNewlineTrackingWriter{writer: w}
written, errCopy := io.Copy(tracker, bodyFile)
if errCopy != nil {
_ = bodyFile.Close() _ = bodyFile.Close()
return errCopy return errCopy
} }
if written > 0 {
bodyTrailingNewlines = tracker.trailingNewlines
}
if errClose := bodyFile.Close(); errClose != nil { if errClose := bodyFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request body temp file") log.WithError(errClose).Warn("failed to close request body temp file")
} }
} else if _, errWrite := w.Write(body); errWrite != nil { } else if _, errWrite := w.Write(body); errWrite != nil {
return errWrite return errWrite
} else if len(body) > 0 {
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
} }
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
return errWrite return errWrite
} }
return nil return nil
} }
// countTrailingNewlinesBytes reports how many consecutive '\n' bytes
// terminate payload. An empty or nil payload yields zero.
func countTrailingNewlinesBytes(payload []byte) int {
	trimmed := bytes.TrimRight(payload, "\n")
	return len(payload) - len(trimmed)
}
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
missingNewlines := 3 - trailingNewlines
if missingNewlines <= 0 {
return nil
}
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
return errWrite
}
// trailingNewlineTrackingWriter wraps an io.Writer and tracks how many
// consecutive '\n' bytes terminate the stream written so far, letting
// callers compute section spacing without buffering the whole payload.
type trailingNewlineTrackingWriter struct {
	// writer is the destination all writes are forwarded to.
	writer io.Writer
	// trailingNewlines counts the '\n' run at the end of everything
	// successfully written so far.
	trailingNewlines int
}
// Write forwards payload to the wrapped writer and updates the running
// count of trailing '\n' bytes based on the portion actually written.
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
	written, errWrite := t.writer.Write(payload)
	if written <= 0 {
		return written, errWrite
	}
	accepted := payload[:written]
	run := countTrailingNewlinesBytes(accepted)
	if run == len(accepted) {
		// The entire chunk was newlines, so the run continues from the
		// previous write instead of restarting.
		t.trailingNewlines += run
	} else {
		t.trailingNewlines = run
	}
	return written, errWrite
}
// hasSectionPayload reports whether payload carries anything beyond
// whitespace, i.e. whether a log section is worth emitting for it.
func hasSectionPayload(payload []byte) bool {
	trimmed := bytes.TrimSpace(payload)
	return len(trimmed) != 0
}
// inferDownstreamTransport classifies the client-facing transport as
// "websocket" or "http". A non-empty websocket timeline is authoritative;
// otherwise an "Upgrade: websocket" request header (matched case- and
// whitespace-insensitively) marks the request as a websocket upgrade.
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
	if hasSectionPayload(websocketTimeline) {
		return "websocket"
	}
	for key, values := range headers {
		if !strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
			continue
		}
		for _, value := range values {
			if strings.EqualFold(strings.TrimSpace(value), "websocket") {
				return "websocket"
			}
		}
	}
	return "http"
}
// inferUpstreamTransport classifies the upstream exchange from the captured
// sections: "websocket+http" when both kinds of traffic were observed,
// "websocket" or "http" when only one was, and "" when nothing was captured.
// The error slice is currently unused but kept for interface stability.
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
	httpSeen := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
	wsSeen := hasSectionPayload(apiWebsocketTimeline)
	if wsSeen && httpSeen {
		return "websocket+http"
	}
	if wsSeen {
		return "websocket"
	}
	if httpSeen {
		return "http"
	}
	return ""
}
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error { func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
if len(payload) == 0 { if len(payload) == 0 {
return nil return nil
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil { if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite return errWrite
} }
if !bytes.HasSuffix(payload, []byte("\n")) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
} else { } else {
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil { if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
return errWrite return errWrite
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil { if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite return errWrite
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
return errWrite return errWrite
} }
return nil return nil
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil { if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
return errWrite return errWrite
} }
trailingNewlines := 1
if apiResponseErrors[i].Error != nil { if apiResponseErrors[i].Error != nil {
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil { errText := apiResponseErrors[i].Error.Error()
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
return errWrite return errWrite
} }
if errText != "" {
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
}
} }
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil { if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
return errWrite return errWrite
} }
} }
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
} }
} }
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { var bufferedReader *bufio.Reader
return errWrite if responseReader != nil {
bufferedReader = bufio.NewReader(responseReader)
}
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
} }
if responseReader != nil { if bufferedReader != nil {
if _, errCopy := io.Copy(w, responseReader); errCopy != nil { if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
return errCopy return errCopy
} }
} }
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
return nil return nil
} }
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
if reader == nil {
return false
}
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
return true
}
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
return true
}
return false
}
// formatLogContent creates the complete log content for non-streaming requests. // formatLogContent creates the complete log content for non-streaming requests.
// //
// Parameters: // Parameters:
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
// - method: The HTTP method // - method: The HTTP method
// - headers: The request headers // - headers: The request headers
// - body: The request body // - body: The request body
// - websocketTimeline: The downstream websocket event timeline
// - apiRequest: The API request data // - apiRequest: The API request data
// - apiResponse: The API response data // - apiResponse: The API response data
// - response: The raw response data // - response: The raw response data
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
// //
// Returns: // Returns:
// - string: The formatted log content // - string: The formatted log content
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string { func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
var content strings.Builder var content strings.Builder
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
// Request info // Request info
content.WriteString(l.formatRequestInfo(url, method, headers, body)) content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
if len(websocketTimeline) > 0 {
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
content.Write(websocketTimeline)
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
content.Write(websocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiWebsocketTimeline) > 0 {
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
content.Write(apiWebsocketTimeline)
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
content.Write(apiWebsocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiRequest) > 0 { if len(apiRequest) > 0 {
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) { if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
content.WriteString("\n") content.WriteString("\n")
} }
if isWebsocketTranscript {
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
// timeline sections instead of a generic downstream HTTP response block.
return content.String()
}
// Response section // Response section
content.WriteString("=== RESPONSE ===\n") content.WriteString("=== RESPONSE ===\n")
content.WriteString(fmt.Sprintf("Status: %d\n", status)) content.WriteString(fmt.Sprintf("Status: %d\n", status))
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
// //
// Returns: // Returns:
// - string: The formatted request information // - string: The formatted request information
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string { func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
var content strings.Builder var content strings.Builder
content.WriteString("=== REQUEST INFO ===\n") content.WriteString("=== REQUEST INFO ===\n")
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version)) content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
content.WriteString(fmt.Sprintf("URL: %s\n", url)) content.WriteString(fmt.Sprintf("URL: %s\n", url))
content.WriteString(fmt.Sprintf("Method: %s\n", method)) content.WriteString(fmt.Sprintf("Method: %s\n", method))
if strings.TrimSpace(downstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
}
if strings.TrimSpace(upstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
}
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))) content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
content.WriteString("\n") content.WriteString("\n")
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
} }
content.WriteString("\n") content.WriteString("\n")
if !includeBody {
return content.String()
}
content.WriteString("=== REQUEST BODY ===\n") content.WriteString("=== REQUEST BODY ===\n")
content.Write(body) content.Write(body)
content.WriteString("\n\n") content.WriteString("\n\n")
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
// apiResponse stores the upstream API response data. // apiResponse stores the upstream API response data.
apiResponse []byte apiResponse []byte
// apiWebsocketTimeline stores the upstream websocket event timeline.
apiWebsocketTimeline []byte
// apiResponseTimestamp captures when the API response was received. // apiResponseTimestamp captures when the API response was received.
apiResponseTimestamp time.Time apiResponseTimestamp time.Time
} }
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
return nil return nil
} }
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
// The payload is cloned so later mutation by the caller cannot corrupt the log.
//
// Parameters:
//   - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
//   - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
	if len(apiWebsocketTimeline) > 0 {
		w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
	}
	return nil
}
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) { func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
if !timestamp.IsZero() { if !timestamp.IsZero() {
w.apiResponseTimestamp = timestamp w.apiResponseTimestamp = timestamp
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
// Close finalizes the log file and cleans up resources. // Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order: // It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks) // API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
// //
// Returns: // Returns:
// - error: An error if closing fails, nil otherwise // - error: An error if closing fails, nil otherwise
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
} }
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error { func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil { if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite return errWrite
} }
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil { if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil return nil
} }
// WriteAPIWebsocketTimeline is a no-op implementation that discards the
// timeline and always succeeds, satisfying the streaming log writer
// interface when request logging is disabled.
//
// Parameters:
//   - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
//
// Returns:
//   - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
	return nil
}
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {} func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
// Close is a no-op implementation that does nothing and always returns nil. // Close is a no-op implementation that does nothing and always returns nil.

View File

@@ -0,0 +1,151 @@
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
package misc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// Endpoint and tuning knobs for the antigravity version lookup.
const (
	// antigravityReleasesURL is the releases API queried for the latest version.
	antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases"
	// antigravityFallbackVersion is returned when no fresh version is cached.
	antigravityFallbackVersion = "1.21.9"
	// antigravityVersionCacheTTL bounds how long a fetched version stays trusted.
	antigravityVersionCacheTTL = 6 * time.Hour
	// antigravityFetchTimeout caps a single HTTP fetch of the releases list.
	antigravityFetchTimeout = 10 * time.Second
)
// antigravityRelease mirrors a single entry of the releases API response;
// only Version is consumed, ExecutionID is decoded but currently unused.
type antigravityRelease struct {
	Version string `json:"version"`
	ExecutionID string `json:"execution_id"`
}
var (
	// cachedAntigravityVersion is the most recently fetched version string,
	// seeded with the fallback so reads never observe an empty value.
	cachedAntigravityVersion = antigravityFallbackVersion
	// antigravityVersionMu guards cachedAntigravityVersion and its expiry.
	antigravityVersionMu sync.RWMutex
	// antigravityVersionExpiry marks when the cached version becomes stale.
	antigravityVersionExpiry time.Time
	// antigravityUpdaterOnce ensures the background updater starts at most once.
	antigravityUpdaterOnce sync.Once
)
// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version.
// This is intentionally decoupled from request execution to avoid blocking executors on version lookups.
// Safe to call multiple times: sync.Once guarantees only the first call spawns
// the updater, so subsequent calls (even with a different ctx) are no-ops.
func StartAntigravityVersionUpdater(ctx context.Context) {
	antigravityUpdaterOnce.Do(func() {
		go runAntigravityVersionUpdater(ctx)
	})
}
// runAntigravityVersionUpdater refreshes the cached version once immediately
// and then on a fixed interval (half the cache TTL) until ctx is cancelled.
func runAntigravityVersionUpdater(ctx context.Context) {
	if ctx == nil {
		ctx = context.Background()
	}
	interval := antigravityVersionCacheTTL / 2
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	log.Infof("periodic antigravity version refresh started (interval=%s)", interval)
	refreshAntigravityVersion(ctx)
	for {
		select {
		case <-ticker.C:
			refreshAntigravityVersion(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// refreshAntigravityVersion performs one fetch attempt and updates the shared
// cache under the write lock. On failure it keeps a still-fresh cached value;
// once the cache is empty or past its expiry it reverts to the fallback
// version and re-arms the expiry so the warning is not logged on every tick.
func refreshAntigravityVersion(ctx context.Context) {
	version, errFetch := fetchAntigravityLatestVersion(ctx)
	antigravityVersionMu.Lock()
	defer antigravityVersionMu.Unlock()
	now := time.Now()
	switch {
	case errFetch == nil:
		cachedAntigravityVersion = version
		antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
		log.WithField("version", version).Info("fetched latest antigravity version")
	case cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry):
		cachedAntigravityVersion = antigravityFallbackVersion
		antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
		log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version")
	default:
		log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value")
	}
}
// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater.
// It falls back to antigravityFallbackVersion if the cache is empty or stale.
func AntigravityLatestVersion() string {
	antigravityVersionMu.RLock()
	defer antigravityVersionMu.RUnlock()
	if cachedAntigravityVersion == "" || !time.Now().Before(antigravityVersionExpiry) {
		return antigravityFallbackVersion
	}
	return cachedAntigravityVersion
}
// AntigravityUserAgent returns the User-Agent string for antigravity requests
// using the latest version fetched from the releases API.
func AntigravityUserAgent() string {
return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion())
}
func fetchAntigravityLatestVersion(ctx context.Context) (string, error) {
if ctx == nil {
ctx = context.Background()
}
client := &http.Client{Timeout: antigravityFetchTimeout}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil)
if errReq != nil {
return "", fmt.Errorf("build antigravity releases request: %w", errReq)
}
resp, errDo := client.Do(httpReq)
if errDo != nil {
return "", fmt.Errorf("fetch antigravity releases: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.WithError(errClose).Warn("antigravity releases response body close error")
}
}()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode)
}
var releases []antigravityRelease
if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil {
return "", fmt.Errorf("decode antigravity releases response: %w", errDecode)
}
if len(releases) == 0 {
return "", errors.New("antigravity releases API returned empty list")
}
version := releases[0].Version
if version == "" {
return "", errors.New("antigravity releases API returned empty version")
}
return version, nil
}

View File

@@ -93,6 +93,54 @@ func GetAntigravityModels() []*ModelInfo {
func GetCodeBuddyModels() []*ModelInfo { func GetCodeBuddyModels() []*ModelInfo {
now := int64(1748044800) // 2025-05-24 now := int64(1748044800) // 2025-05-24
return []*ModelInfo{ return []*ModelInfo{
{
ID: "auto",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Auto",
Description: "Automatic model selection via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "glm-5v-turbo",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-5v Turbo",
Description: "GLM-5v Turbo via CodeBuddy",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "glm-5.1",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-5.1",
Description: "GLM-5.1 via CodeBuddy",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "glm-5.0-turbo",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-5.0 Turbo",
Description: "GLM-5.0 Turbo via CodeBuddy",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{ {
ID: "glm-5.0", ID: "glm-5.0",
Object: "model", Object: "model",
@@ -101,7 +149,7 @@ func GetCodeBuddyModels() []*ModelInfo {
Type: "codebuddy", Type: "codebuddy",
DisplayName: "GLM-5.0", DisplayName: "GLM-5.0",
Description: "GLM-5.0 via CodeBuddy", Description: "GLM-5.0 via CodeBuddy",
ContextLength: 128000, ContextLength: 200000,
MaxCompletionTokens: 32768, MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
}, },
@@ -113,18 +161,18 @@ func GetCodeBuddyModels() []*ModelInfo {
Type: "codebuddy", Type: "codebuddy",
DisplayName: "GLM-4.7", DisplayName: "GLM-4.7",
Description: "GLM-4.7 via CodeBuddy", Description: "GLM-4.7 via CodeBuddy",
ContextLength: 128000, ContextLength: 200000,
MaxCompletionTokens: 32768, MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
}, },
{ {
ID: "minimax-m2.5", ID: "minimax-m2.7",
Object: "model", Object: "model",
Created: now, Created: now,
OwnedBy: "tencent", OwnedBy: "tencent",
Type: "codebuddy", Type: "codebuddy",
DisplayName: "MiniMax M2.5", DisplayName: "MiniMax M2.7",
Description: "MiniMax M2.5 via CodeBuddy", Description: "MiniMax M2.7 via CodeBuddy",
ContextLength: 200000, ContextLength: 200000,
MaxCompletionTokens: 32768, MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
@@ -137,10 +185,23 @@ func GetCodeBuddyModels() []*ModelInfo {
Type: "codebuddy", Type: "codebuddy",
DisplayName: "Kimi K2.5", DisplayName: "Kimi K2.5",
Description: "Kimi K2.5 via CodeBuddy", Description: "Kimi K2.5 via CodeBuddy",
ContextLength: 128000, ContextLength: 256000,
MaxCompletionTokens: 32768, MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
}, },
{
ID: "kimi-k2-thinking",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Kimi K2 Thinking",
Description: "Kimi K2 Thinking via CodeBuddy",
ContextLength: 256000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{ZeroAllowed: true},
SupportedEndpoints: []string{"/chat/completions"},
},
{ {
ID: "deepseek-v3-2-volc", ID: "deepseek-v3-2-volc",
Object: "model", Object: "model",
@@ -148,24 +209,11 @@ func GetCodeBuddyModels() []*ModelInfo {
OwnedBy: "tencent", OwnedBy: "tencent",
Type: "codebuddy", Type: "codebuddy",
DisplayName: "DeepSeek V3.2 (Volc)", DisplayName: "DeepSeek V3.2 (Volc)",
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)", Description: "DeepSeek V3.2 via CodeBuddy",
ContextLength: 128000, ContextLength: 128000,
MaxCompletionTokens: 32768, MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
}, },
{
ID: "hunyuan-2.0-thinking",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Hunyuan 2.0 Thinking",
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{ZeroAllowed: true},
SupportedEndpoints: []string{"/chat/completions"},
},
} }
} }
@@ -287,6 +335,13 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil return nil
} }
// defaultCopilotClaudeContextLength is the conservative prompt token limit
// used as the static ContextLength for Claude models accessed via the GitHub
// Copilot API. Individual accounts are capped at 128K and business accounts
// at 168K — both below the nominal 200K window — so advertising 200K causes
// intermittent "prompt token count exceeds the limit" rejections. When the
// dynamic /models API fetch succeeds, the real per-account limit overrides
// this value; this constant is only the safe fallback.
const defaultCopilotClaudeContextLength = 128000
// GetGitHubCopilotModels returns the available models for GitHub Copilot. // GetGitHubCopilotModels returns the available models for GitHub Copilot.
// These models are available through the GitHub Copilot API at api.githubcopilot.com. // These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo { func GetGitHubCopilotModels() []*ModelInfo {
@@ -477,6 +532,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"}, SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
}, },
{
ID: "gpt-5.4-mini",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.4 mini",
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{ {
ID: "claude-haiku-4.5", ID: "claude-haiku-4.5",
Object: "model", Object: "model",
@@ -485,7 +553,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Haiku 4.5", DisplayName: "Claude Haiku 4.5",
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot", Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 64000, MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
}, },
@@ -497,7 +565,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Opus 4.1", DisplayName: "Claude Opus 4.1",
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot", Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 32000, MaxCompletionTokens: 32000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
}, },
@@ -509,9 +577,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Opus 4.5", DisplayName: "Claude Opus 4.5",
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot", Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 64000, MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
}, },
{ {
ID: "claude-opus-4.6", ID: "claude-opus-4.6",
@@ -521,9 +590,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Opus 4.6", DisplayName: "Claude Opus 4.6",
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot", Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 64000, MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
}, },
{ {
ID: "claude-sonnet-4", ID: "claude-sonnet-4",
@@ -533,9 +603,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Sonnet 4", DisplayName: "Claude Sonnet 4",
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot", Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 64000, MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
}, },
{ {
ID: "claude-sonnet-4.5", ID: "claude-sonnet-4.5",
@@ -545,9 +616,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Sonnet 4.5", DisplayName: "Claude Sonnet 4.5",
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot", Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 64000, MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
}, },
{ {
ID: "claude-sonnet-4.6", ID: "claude-sonnet-4.6",
@@ -557,9 +629,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Claude Sonnet 4.6", DisplayName: "Claude Sonnet 4.6",
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot", Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
ContextLength: 200000, ContextLength: defaultCopilotClaudeContextLength,
MaxCompletionTokens: 64000, MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"}, SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
}, },
{ {
ID: "gemini-2.5-pro", ID: "gemini-2.5-pro",
@@ -571,6 +644,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Google Gemini 2.5 Pro via GitHub Copilot", Description: "Google Gemini 2.5 Pro via GitHub Copilot",
ContextLength: 1048576, ContextLength: 1048576,
MaxCompletionTokens: 65536, MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
}, },
{ {
ID: "gemini-3-pro-preview", ID: "gemini-3-pro-preview",
@@ -582,6 +656,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Google Gemini 3 Pro Preview via GitHub Copilot", Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
ContextLength: 1048576, ContextLength: 1048576,
MaxCompletionTokens: 65536, MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
}, },
{ {
ID: "gemini-3.1-pro-preview", ID: "gemini-3.1-pro-preview",
@@ -591,8 +666,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)", DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot", Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576, ContextLength: 173000,
MaxCompletionTokens: 65536, MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
}, },
{ {
ID: "gemini-3-flash-preview", ID: "gemini-3-flash-preview",
@@ -602,8 +678,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot", Type: "github-copilot",
DisplayName: "Gemini 3 Flash (Preview)", DisplayName: "Gemini 3 Flash (Preview)",
Description: "Google Gemini 3 Flash Preview via GitHub Copilot", Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
ContextLength: 1048576, ContextLength: 173000,
MaxCompletionTokens: 65536, MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
}, },
{ {
ID: "grok-code-fast-1", ID: "grok-code-fast-1",

View File

@@ -0,0 +1,29 @@
package registry
import "testing"
// TestGitHubCopilotGeminiModelsAreChatOnly checks two invariants for the
// Gemini models exposed via GitHub Copilot: every expected Gemini model ID
// is present in the definitions, and each one advertises exactly the
// "/chat/completions" endpoint and nothing else.
func TestGitHubCopilotGeminiModelsAreChatOnly(t *testing.T) {
	wanted := []string{
		"gemini-2.5-pro",
		"gemini-3-pro-preview",
		"gemini-3.1-pro-preview",
		"gemini-3-flash-preview",
	}
	// Track which of the wanted IDs actually appear in the model list.
	seen := make(map[string]bool, len(wanted))
	for _, m := range GetGitHubCopilotModels() {
		isGemini := false
		for _, id := range wanted {
			if m.ID == id {
				isGemini = true
				break
			}
		}
		if !isGemini {
			continue
		}
		seen[m.ID] = true
		if len(m.SupportedEndpoints) != 1 || m.SupportedEndpoints[0] != "/chat/completions" {
			t.Fatalf("model %q supported endpoints = %v, want [/chat/completions]", m.ID, m.SupportedEndpoints)
		}
	}
	for _, id := range wanted {
		if !seen[id] {
			t.Fatalf("expected GitHub Copilot model %q in definitions", id)
		}
	}
}

View File

@@ -1177,6 +1177,16 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
"dynamic_allowed": model.Thinking.DynamicAllowed, "dynamic_allowed": model.Thinking.DynamicAllowed,
} }
} }
// Include context limits so Claude Code can manage conversation
// context correctly, especially for Copilot-proxied models whose
// real prompt limit (128K-168K) is much lower than the 1M window
// that Claude Code may assume for Opus 4.6 with 1M context enabled.
if model.ContextLength > 0 {
result["context_length"] = model.ContextLength
}
if model.MaxCompletionTokens > 0 {
result["max_completion_tokens"] = model.MaxCompletionTokens
}
return result return result
case "gemini": case "gemini":

View File

@@ -280,6 +280,7 @@
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"low", "low",
"medium",
"high" "high"
] ]
} }
@@ -554,6 +555,7 @@
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"low", "low",
"medium",
"high" "high"
] ]
} }
@@ -610,6 +612,8 @@
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"minimal", "minimal",
"low",
"medium",
"high" "high"
] ]
} }
@@ -838,6 +842,7 @@
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"low", "low",
"medium",
"high" "high"
] ]
} }
@@ -896,6 +901,8 @@
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"minimal", "minimal",
"low",
"medium",
"high" "high"
] ]
} }
@@ -1070,6 +1077,8 @@
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"minimal", "minimal",
"low",
"medium",
"high" "high"
] ]
} }
@@ -1371,6 +1380,75 @@
"xhigh" "xhigh"
] ]
} }
},
{
"id": "gpt-5.3-codex",
"object": "model",
"created": 1770307200,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.3 Codex",
"version": "gpt-5.3",
"description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.4",
"object": "model",
"created": 1772668800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4",
"version": "gpt-5.4",
"description": "Stable version of GPT 5.4",
"context_length": 1050000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.4-mini",
"object": "model",
"created": 1773705600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4 Mini",
"version": "gpt-5.4-mini",
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
} }
], ],
"codex-team": [ "codex-team": [
@@ -1623,6 +1701,29 @@
"xhigh" "xhigh"
] ]
} }
},
{
"id": "gpt-5.4-mini",
"object": "model",
"created": 1773705600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4 Mini",
"version": "gpt-5.4-mini",
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
} }
], ],
"codex-plus": [ "codex-plus": [
@@ -1898,6 +1999,29 @@
"xhigh" "xhigh"
] ]
} }
},
{
"id": "gpt-5.4-mini",
"object": "model",
"created": 1773705600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4 Mini",
"version": "gpt-5.4-mini",
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
} }
], ],
"codex-pro": [ "codex-pro": [
@@ -2173,55 +2297,40 @@
"xhigh" "xhigh"
] ]
} }
},
{
"id": "gpt-5.4-mini",
"object": "model",
"created": 1773705600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4 Mini",
"version": "gpt-5.4-mini",
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
} }
], ],
"qwen": [ "qwen": [
{
"id": "qwen3-coder-plus",
"object": "model",
"created": 1753228800,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen3 Coder Plus",
"version": "3.0",
"description": "Advanced code generation and understanding model",
"context_length": 32768,
"max_completion_tokens": 8192,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
},
{
"id": "qwen3-coder-flash",
"object": "model",
"created": 1753228800,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen3 Coder Flash",
"version": "3.0",
"description": "Fast code generation model",
"context_length": 8192,
"max_completion_tokens": 2048,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
},
{ {
"id": "coder-model", "id": "coder-model",
"object": "model", "object": "model",
"created": 1771171200, "created": 1771171200,
"owned_by": "qwen", "owned_by": "qwen",
"type": "qwen", "type": "qwen",
"display_name": "Qwen 3.5 Plus", "display_name": "Qwen 3.6 Plus",
"version": "3.5", "version": "3.6",
"description": "efficient hybrid model with leading coding performance", "description": "efficient hybrid model with leading coding performance",
"context_length": 1048576, "context_length": 1048576,
"max_completion_tokens": 65536, "max_completion_tokens": 65536,
@@ -2232,25 +2341,6 @@
"stream", "stream",
"stop" "stop"
] ]
},
{
"id": "vision-model",
"object": "model",
"created": 1758672000,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen3 Vision Model",
"version": "3.0",
"description": "Vision model model",
"context_length": 32768,
"max_completion_tokens": 2048,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
} }
], ],
"iflow": [ "iflow": [
@@ -2639,11 +2729,12 @@
"context_length": 1048576, "context_length": 1048576,
"max_completion_tokens": 65535, "max_completion_tokens": 65535,
"thinking": { "thinking": {
"min": 128, "min": 1,
"max": 32768, "max": 65535,
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"low", "low",
"medium",
"high" "high"
] ]
} }
@@ -2659,11 +2750,12 @@
"context_length": 1048576, "context_length": 1048576,
"max_completion_tokens": 65535, "max_completion_tokens": 65535,
"thinking": { "thinking": {
"min": 128, "min": 1,
"max": 32768, "max": 65535,
"dynamic_allowed": true, "dynamic_allowed": true,
"levels": [ "levels": [
"low", "low",
"medium",
"high" "high"
] ]
} }

View File

@@ -14,7 +14,9 @@ import (
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" "github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -46,8 +48,16 @@ func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Man
// Identifier returns the executor identifier. // Identifier returns the executor identifier.
func (e *AIStudioExecutor) Identifier() string { return "aistudio" } func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio). // PrepareRequest prepares the HTTP request for execution.
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { func (e *AIStudioExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil return nil
} }
@@ -66,6 +76,9 @@ func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.A
return nil, fmt.Errorf("aistudio executor: missing auth") return nil, fmt.Errorf("aistudio executor: missing auth")
} }
httpReq := req.WithContext(ctx) httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" { if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
return nil, fmt.Errorf("aistudio executor: request URL is empty") return nil, fmt.Errorf("aistudio executor: request URL is empty")
} }
@@ -115,8 +128,8 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
} }
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false) translatedReq, body, err := e.translateRequest(req, opts, false)
if err != nil { if err != nil {
@@ -130,6 +143,11 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
Headers: http.Header{"Content-Type": []string{"application/json"}}, Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload, Body: body.payload,
} }
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -137,7 +155,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint, URL: endpoint,
Method: http.MethodPost, Method: http.MethodPost,
Headers: wsReq.Headers.Clone(), Headers: wsReq.Headers.Clone(),
@@ -151,17 +169,17 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
wsResp, err := e.relay.NonStream(ctx, authID, wsReq) wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
if len(wsResp.Body) > 0 { if len(wsResp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body) helps.AppendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
} }
if wsResp.Status < 200 || wsResp.Status >= 300 { if wsResp.Status < 200 || wsResp.Status >= 300 {
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)} return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
} }
reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) reporter.Publish(ctx, helps.ParseGeminiUsage(wsResp.Body))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param) out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()} resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
@@ -174,8 +192,8 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
} }
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true) translatedReq, body, err := e.translateRequest(req, opts, true)
if err != nil { if err != nil {
@@ -189,13 +207,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
Headers: http.Header{"Content-Type": []string{"application/json"}}, Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload, Body: body.payload,
} }
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint, URL: endpoint,
Method: http.MethodPost, Method: http.MethodPost,
Headers: wsReq.Headers.Clone(), Headers: wsReq.Headers.Clone(),
@@ -208,24 +231,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}) })
wsStream, err := e.relay.Stream(ctx, authID, wsReq) wsStream, err := e.relay.Stream(ctx, authID, wsReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
firstEvent, ok := <-wsStream firstEvent, ok := <-wsStream
if !ok { if !ok {
err = fmt.Errorf("wsrelay: stream closed before start") err = fmt.Errorf("wsrelay: stream closed before start")
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK { if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
metadataLogged := false metadataLogged := false
if firstEvent.Status > 0 { if firstEvent.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
metadataLogged = true metadataLogged = true
} }
var body bytes.Buffer var body bytes.Buffer
if len(firstEvent.Payload) > 0 { if len(firstEvent.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload) helps.AppendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
body.Write(firstEvent.Payload) body.Write(firstEvent.Payload)
} }
if firstEvent.Type == wsrelay.MessageTypeStreamEnd { if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
@@ -233,18 +256,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
} }
for event := range wsStream { for event := range wsStream {
if event.Err != nil { if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err) helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
if body.Len() == 0 { if body.Len() == 0 {
body.WriteString(event.Err.Error()) body.WriteString(event.Err.Error())
} }
break break
} }
if !metadataLogged && event.Status > 0 { if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true metadataLogged = true
} }
if len(event.Payload) > 0 { if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload) helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
body.Write(event.Payload) body.Write(event.Payload)
} }
if event.Type == wsrelay.MessageTypeStreamEnd { if event.Type == wsrelay.MessageTypeStreamEnd {
@@ -260,23 +283,23 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
metadataLogged := false metadataLogged := false
processEvent := func(event wsrelay.StreamEvent) bool { processEvent := func(event wsrelay.StreamEvent) bool {
if event.Err != nil { if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err) helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false return false
} }
switch event.Type { switch event.Type {
case wsrelay.MessageTypeStreamStart: case wsrelay.MessageTypeStreamStart:
if !metadataLogged && event.Status > 0 { if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true metadataLogged = true
} }
case wsrelay.MessageTypeStreamChunk: case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 { if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload) helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
filtered := FilterSSEUsageMetadata(event.Payload) filtered := helps.FilterSSEUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok { if detail, ok := helps.ParseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param) lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines { for i := range lines {
@@ -288,21 +311,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return false return false
case wsrelay.MessageTypeHTTPResp: case wsrelay.MessageTypeHTTPResp:
if !metadataLogged && event.Status > 0 { if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true metadataLogged = true
} }
if len(event.Payload) > 0 { if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload) helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
} }
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param) lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
} }
reporter.publish(ctx, parseGeminiUsage(event.Payload)) reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
return false return false
case wsrelay.MessageTypeError: case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err) helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false return false
} }
@@ -345,7 +368,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint, URL: endpoint,
Method: http.MethodPost, Method: http.MethodPost,
Headers: wsReq.Headers.Clone(), Headers: wsReq.Headers.Clone(),
@@ -358,12 +381,12 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
}) })
resp, err := e.relay.NonStream(ctx, authID, wsReq) resp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
} }
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
if len(resp.Body) > 0 { if len(resp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, resp.Body) helps.AppendAPIResponseChunk(ctx, e.cfg, resp.Body)
} }
if resp.Status < 200 || resp.Status >= 300 { if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
@@ -404,8 +427,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
return nil, translatedPayload{}, err return nil, translatedPayload{}, err
} }
payload = fixGeminiImageAspectRatio(baseModel, payload) payload = fixGeminiImageAspectRatio(baseModel, payload)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel) payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens") payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType") payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema") payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -24,6 +24,8 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -44,15 +46,44 @@ const (
antigravityGeneratePath = "/v1internal:generateContent" antigravityGeneratePath = "/v1internal:generateContent"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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" antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64" defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent()
antigravityAuthType = "antigravity" antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second refreshSkew = 3000 * time.Second
antigravityCreditsRetryTTL = 5 * time.Hour
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**" // systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
) )
type antigravity429Category string
const (
antigravity429Unknown antigravity429Category = "unknown"
antigravity429RateLimited antigravity429Category = "rate_limited"
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
var ( var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano())) randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex randSourceMutex sync.Mutex
antigravityCreditsExhaustedByAuth sync.Map
antigravityPreferCreditsByModel sync.Map
antigravityQuotaExhaustedKeywords = []string{
"quota_exhausted",
"quota exhausted",
}
antigravityCreditsExhaustedKeywords = []string{
"google_one_ai",
"insufficient credit",
"insufficient credits",
"not enough credit",
"not enough credits",
"credit exhausted",
"credits exhausted",
"credit balance",
"minimumcreditamountforusage",
"minimum credit amount for usage",
"minimum credit",
"resource has been exhausted",
}
) )
// AntigravityExecutor proxies requests to the antigravity upstream. // AntigravityExecutor proxies requests to the antigravity upstream.
@@ -113,7 +144,7 @@ func initAntigravityTransport() {
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
antigravityTransportOnce.Do(initAntigravityTransport) antigravityTransportOnce.Do(initAntigravityTransport)
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout) client := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
// If no transport is set, use the shared HTTP/1.1 transport. // If no transport is set, use the shared HTTP/1.1 transport.
if client.Transport == nil { if client.Transport == nil {
client.Transport = antigravityTransport client.Transport = antigravityTransport
@@ -183,6 +214,259 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
func injectEnabledCreditTypes(payload []byte) []byte {
if len(payload) == 0 {
return nil
}
if !gjson.ValidBytes(payload) {
return nil
}
updated, err := sjson.SetRawBytes(payload, "enabledCreditTypes", []byte(`["GOOGLE_ONE_AI"]`))
if err != nil {
return nil
}
return updated
}
func classifyAntigravity429(body []byte) antigravity429Category {
if len(body) == 0 {
return antigravity429Unknown
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityQuotaExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return antigravity429QuotaExhausted
}
}
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
return antigravity429Unknown
}
details := gjson.GetBytes(body, "error.details")
if !details.Exists() || !details.IsArray() {
return antigravity429Unknown
}
for _, detail := range details.Array() {
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
continue
}
reason := strings.TrimSpace(detail.Get("reason").String())
if strings.EqualFold(reason, "QUOTA_EXHAUSTED") {
return antigravity429QuotaExhausted
}
if strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED") {
return antigravity429RateLimited
}
}
return antigravity429Unknown
}
func antigravityHasQuotaResetDelayOrModelInfo(body []byte) bool {
if len(body) == 0 {
return false
}
details := gjson.GetBytes(body, "error.details")
if !details.Exists() || !details.IsArray() {
return false
}
for _, detail := range details.Array() {
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
continue
}
if strings.TrimSpace(detail.Get("metadata.quotaResetDelay").String()) != "" {
return true
}
if strings.TrimSpace(detail.Get("metadata.model").String()) != "" {
return true
}
}
return false
}
func antigravityCreditsRetryEnabled(cfg *config.Config) bool {
return cfg != nil && cfg.QuotaExceeded.AntigravityCredits
}
func antigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) bool {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return false
}
value, ok := antigravityCreditsExhaustedByAuth.Load(auth.ID)
if !ok {
return false
}
until, ok := value.(time.Time)
if !ok || until.IsZero() {
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
return false
}
if !until.After(now) {
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
return false
}
return true
}
func markAntigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
}
antigravityCreditsExhaustedByAuth.Store(auth.ID, now.Add(antigravityCreditsRetryTTL))
}
func clearAntigravityCreditsExhausted(auth *cliproxyauth.Auth) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
}
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
}
func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string {
if auth == nil {
return ""
}
authID := strings.TrimSpace(auth.ID)
modelName = strings.TrimSpace(modelName)
if authID == "" || modelName == "" {
return ""
}
return authID + "|" + modelName
}
func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return false
}
value, ok := antigravityPreferCreditsByModel.Load(key)
if !ok {
return false
}
until, ok := value.(time.Time)
if !ok || until.IsZero() {
antigravityPreferCreditsByModel.Delete(key)
return false
}
if !until.After(now) {
antigravityPreferCreditsByModel.Delete(key)
return false
}
return true
}
func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return
}
until := now.Add(antigravityCreditsRetryTTL)
if retryAfter != nil && *retryAfter > 0 {
until = now.Add(*retryAfter)
}
antigravityPreferCreditsByModel.Store(key, until)
}
func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return
}
antigravityPreferCreditsByModel.Delete(key)
}
func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool {
if reqErr != nil || statusCode == 0 {
return false
}
if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout {
return false
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityCreditsExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
if keyword == "resource has been exhausted" &&
statusCode == http.StatusTooManyRequests &&
classifyAntigravity429(body) == antigravity429Unknown &&
!antigravityHasQuotaResetDelayOrModelInfo(body) {
return false
}
return true
}
}
return false
}
func newAntigravityStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if statusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil {
err.retryAfter = retryAfter
}
}
return err
}
func (e *AntigravityExecutor) attemptCreditsFallback(
ctx context.Context,
auth *cliproxyauth.Auth,
httpClient *http.Client,
token string,
modelName string,
payload []byte,
stream bool,
alt string,
baseURL string,
originalBody []byte,
) (*http.Response, bool) {
if !antigravityCreditsRetryEnabled(e.cfg) {
return nil, false
}
if classifyAntigravity429(originalBody) != antigravity429QuotaExhausted {
return nil, false
}
now := time.Now()
if antigravityCreditsExhausted(auth, now) {
return nil, false
}
creditsPayload := injectEnabledCreditTypes(payload)
if len(creditsPayload) == 0 {
return nil, false
}
httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL)
if errReq != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errReq)
return nil, true
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, true
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
retryAfter, _ := parseRetryDelay(originalBody)
markAntigravityPreferCredits(auth, modelName, now, retryAfter)
clearAntigravityCreditsExhausted(auth)
return httpResp, true
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, true
}
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, modelName)
markAntigravityCreditsExhausted(auth, now)
}
return nil, true
}
// Execute performs a non-streaming request to the Antigravity API. // Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" { if opts.Alt == "responses/compact" {
@@ -203,8 +487,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
auth = updatedAuth auth = updatedAuth
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
@@ -222,8 +506,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -237,7 +521,15 @@ attemptLoop:
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL) requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, false, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -245,7 +537,7 @@ attemptLoop:
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo return resp, errDo
} }
@@ -260,20 +552,50 @@ attemptLoop:
return resp, err return resp, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body) bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose) log.Errorf("antigravity executor: close response body error: %v", errClose)
} }
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead err = errRead
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, bodyBytes) helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone())
creditsBody, errCreditsRead := io.ReadAll(creditsResp.Body)
if errClose := creditsResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close credits success response body error: %v", errClose)
}
if errCreditsRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead)
err = errCreditsRead
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody)
reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()}
reporter.EnsurePublished(ctx)
return resp, nil
}
}
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes)) log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...) lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil lastErr = nil
@@ -281,6 +603,14 @@ attemptLoop:
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue continue
} }
if antigravityShouldRetryTransientResourceExhausted429(httpResp.StatusCode, bodyBytes) && attempt+1 < attempts {
delay := antigravityTransient429RetryDelay(attempt)
log.Debugf("antigravity executor: transient 429 resource exhausted for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return resp, errWait
}
continue attemptLoop
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) { if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) { if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
@@ -295,33 +625,21 @@ attemptLoop:
continue attemptLoop continue attemptLoop
} }
} }
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)} err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err return resp, err
} }
reporter.publish(ctx, parseAntigravityUsage(bodyBytes)) reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes))
var param any var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param) converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
return resp, nil return resp, nil
} }
switch { switch {
case lastStatus != 0: case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)} err = newAntigravityStatusErr(lastStatus, lastBody)
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil: case lastErr != nil:
err = lastErr err = lastErr
default: default:
@@ -345,8 +663,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
auth = updatedAuth auth = updatedAuth
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
@@ -364,8 +682,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -379,7 +697,15 @@ attemptLoop:
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL) requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -387,7 +713,7 @@ attemptLoop:
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo return resp, errDo
} }
@@ -401,14 +727,14 @@ attemptLoop:
err = errDo err = errDo
return resp, err return resp, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body) bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose) log.Errorf("antigravity executor: close response body error: %v", errClose)
} }
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) { if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead err = errRead
return resp, err return resp, err
@@ -427,7 +753,24 @@ attemptLoop:
err = errRead err = errRead
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, bodyBytes) helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
httpResp = creditsResp
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
}
}
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
goto streamSuccessClaudeNonStream
}
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...) lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil lastErr = nil
@@ -435,6 +778,14 @@ attemptLoop:
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue continue
} }
if antigravityShouldRetryTransientResourceExhausted429(httpResp.StatusCode, bodyBytes) && attempt+1 < attempts {
delay := antigravityTransient429RetryDelay(attempt)
log.Debugf("antigravity executor: transient 429 resource exhausted for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return resp, errWait
}
continue attemptLoop
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) { if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) { if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
@@ -449,16 +800,11 @@ attemptLoop:
continue attemptLoop continue attemptLoop
} }
} }
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)} err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err return resp, err
} }
streamSuccessClaudeNonStream:
out := make(chan cliproxyexecutor.StreamChunk) out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) { go func(resp *http.Response) {
defer close(out) defer close(out)
@@ -471,29 +817,29 @@ attemptLoop:
scanner.Buffer(nil, streamScannerBuffer) scanner.Buffer(nil, streamScannerBuffer)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models // Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk // Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line) line = helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(line) payload := helps.JSONPayload(line)
if payload == nil { if payload == nil {
continue continue
} }
if detail, ok := parseAntigravityStreamUsage(payload); ok { if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
out <- cliproxyexecutor.StreamChunk{Payload: payload} out <- cliproxyexecutor.StreamChunk{Payload: payload}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else { } else {
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
} }
}(httpResp) }(httpResp)
@@ -509,24 +855,18 @@ attemptLoop:
} }
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())} resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
reporter.publish(ctx, parseAntigravityUsage(resp.Payload)) reporter.Publish(ctx, helps.ParseAntigravityUsage(resp.Payload))
var param any var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param) converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
return resp, nil return resp, nil
} }
switch { switch {
case lastStatus != 0: case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)} err = newAntigravityStatusErr(lastStatus, lastBody)
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil: case lastErr != nil:
err = lastErr err = lastErr
default: default:
@@ -748,8 +1088,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
auth = updatedAuth auth = updatedAuth
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
@@ -767,8 +1107,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
return nil, err return nil, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -782,14 +1122,22 @@ attemptLoop:
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL) requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return nil, err return nil, err
} }
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil, errDo return nil, errDo
} }
@@ -803,14 +1151,14 @@ attemptLoop:
err = errDo err = errDo
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body) bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose) log.Errorf("antigravity executor: close response body error: %v", errClose)
} }
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) { if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead err = errRead
return nil, err return nil, err
@@ -829,7 +1177,24 @@ attemptLoop:
err = errRead err = errRead
return nil, err return nil, err
} }
appendAPIResponseChunk(ctx, e.cfg, bodyBytes) helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
httpResp = creditsResp
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
}
}
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
goto streamSuccessExecuteStream
}
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...) lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil lastErr = nil
@@ -837,6 +1202,14 @@ attemptLoop:
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue continue
} }
if antigravityShouldRetryTransientResourceExhausted429(httpResp.StatusCode, bodyBytes) && attempt+1 < attempts {
delay := antigravityTransient429RetryDelay(attempt)
log.Debugf("antigravity executor: transient 429 resource exhausted for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return nil, errWait
}
continue attemptLoop
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) { if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) { if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
@@ -851,16 +1224,11 @@ attemptLoop:
continue attemptLoop continue attemptLoop
} }
} }
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)} err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return nil, err return nil, err
} }
streamSuccessExecuteStream:
out := make(chan cliproxyexecutor.StreamChunk) out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) { go func(resp *http.Response) {
defer close(out) defer close(out)
@@ -874,19 +1242,19 @@ attemptLoop:
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models // Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk // Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line) line = helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(line) payload := helps.JSONPayload(line)
if payload == nil { if payload == nil {
continue continue
} }
if detail, ok := parseAntigravityStreamUsage(payload); ok { if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
@@ -899,11 +1267,11 @@ attemptLoop:
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]} out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else { } else {
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
} }
}(httpResp) }(httpResp)
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -911,13 +1279,7 @@ attemptLoop:
switch { switch {
case lastStatus != 0: case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)} err = newAntigravityStatusErr(lastStatus, lastBody)
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil: case lastErr != nil:
err = lastErr err = lastErr
default: default:
@@ -1011,8 +1373,13 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if host := resolveHost(base); host != "" { if host := resolveHost(base); host != "" {
httpReq.Host = host httpReq.Host = host
} }
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: requestURL.String(), URL: requestURL.String(),
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -1026,7 +1393,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return cliproxyexecutor.Response{}, errDo return cliproxyexecutor.Response{}, errDo
} }
@@ -1040,16 +1407,16 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
return cliproxyexecutor.Response{}, errDo return cliproxyexecutor.Response{}, errDo
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body) bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose) log.Errorf("antigravity executor: close response body error: %v", errClose)
} }
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead return cliproxyexecutor.Response{}, errRead
} }
appendAPIResponseChunk(ctx, e.cfg, bodyBytes) helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int() count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
@@ -1305,6 +1672,11 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
if host := resolveHost(base); host != "" { if host := resolveHost(base); host != "" {
httpReq.Host = host httpReq.Host = host
} }
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -1316,7 +1688,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
if e.cfg != nil && e.cfg.RequestLog { if e.cfg != nil && e.cfg.RequestLog {
payloadLog = []byte(payloadStr) payloadLog = []byte(payloadStr)
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: requestURL.String(), URL: requestURL.String(),
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -1420,7 +1792,7 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
} }
} }
} }
return defaultAntigravityAgent return misc.AntigravityUserAgent()
} }
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int { func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
@@ -1454,6 +1826,24 @@ func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
return strings.Contains(msg, "no capacity available") return strings.Contains(msg, "no capacity available")
} }
func antigravityShouldRetryTransientResourceExhausted429(statusCode int, body []byte) bool {
if statusCode != http.StatusTooManyRequests {
return false
}
if len(body) == 0 {
return false
}
if classifyAntigravity429(body) != antigravity429Unknown {
return false
}
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
return false
}
msg := strings.ToLower(string(body))
return strings.Contains(msg, "resource has been exhausted")
}
func antigravityNoCapacityRetryDelay(attempt int) time.Duration { func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
if attempt < 0 { if attempt < 0 {
attempt = 0 attempt = 0
@@ -1465,6 +1855,17 @@ func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
return delay return delay
} }
func antigravityTransient429RetryDelay(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := time.Duration(attempt+1) * 100 * time.Millisecond
if delay > 500*time.Millisecond {
delay = 500 * time.Millisecond
}
return delay
}
func antigravityWait(ctx context.Context, wait time.Duration) error { func antigravityWait(ctx context.Context, wait time.Duration) error {
if wait <= 0 { if wait <= 0 {
return nil return nil
@@ -1479,7 +1880,7 @@ func antigravityWait(ctx context.Context, wait time.Duration) error {
} }
} }
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { var antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
if base := resolveCustomAntigravityBaseURL(auth); base != "" { if base := resolveCustomAntigravityBaseURL(auth); base != "" {
return []string{base} return []string{base}
} }

View File

@@ -0,0 +1,489 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// resetAntigravityCreditsRetryState clears the package-level credit retry
// state so each test starts from a clean slate.
//
// The state lives in package-level sync.Map values. Assigning a fresh
// sync.Map{} would copy a value that contains a sync.Mutex, which `go vet`
// (copylocks) flags; deleting every key in place is vet-clean and equivalent
// for test isolation. sync.Map.Range permits deletion during iteration.
func resetAntigravityCreditsRetryState() {
	antigravityCreditsExhaustedByAuth.Range(func(key, _ any) bool {
		antigravityCreditsExhaustedByAuth.Delete(key)
		return true
	})
	antigravityPreferCreditsByModel.Range(func(key, _ any) bool {
		antigravityPreferCreditsByModel.Delete(key)
		return true
	})
}
func TestClassifyAntigravity429(t *testing.T) {
t.Run("quota exhausted", func(t *testing.T) {
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
}
})
t.Run("structured rate limit", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
if got := classifyAntigravity429(body); got != antigravity429RateLimited {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited)
}
})
t.Run("structured quota exhausted", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "QUOTA_EXHAUSTED"}
]
}
}`)
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
}
})
t.Run("unknown", func(t *testing.T) {
body := []byte(`{"error":{"message":"too many requests"}}`)
if got := classifyAntigravity429(body); got != antigravity429Unknown {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429Unknown)
}
})
}
// TestInjectEnabledCreditTypes checks that a valid request body gains the
// GOOGLE_ONE_AI credit type and that invalid JSON yields nil.
func TestInjectEnabledCreditTypes(t *testing.T) {
	input := []byte(`{"model":"gemini-2.5-flash","request":{}}`)
	out := injectEnabledCreditTypes(input)
	if out == nil {
		t.Fatal("injectEnabledCreditTypes() returned nil")
	}
	if !strings.Contains(string(out), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
		t.Fatalf("injectEnabledCreditTypes() = %s, want enabledCreditTypes", string(out))
	}
	// Malformed input must be rejected rather than passed through.
	if invalid := injectEnabledCreditTypes([]byte(`not json`)); invalid != nil {
		t.Fatalf("injectEnabledCreditTypes() for invalid json = %s, want nil", string(invalid))
	}
}
// TestShouldMarkAntigravityCreditsExhausted checks which upstream errors
// should flag an auth as out of credits: explicit credit errors and 429s with
// quota metadata are marked; transient RESOURCE_EXHAUSTED 429s and 5xx
// responses are not.
func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) {
	t.Run("credit errors are marked", func(t *testing.T) {
		creditBodies := []string{
			`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`,
			`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`,
		}
		for _, raw := range creditBodies {
			if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, []byte(raw), nil) {
				t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", raw)
			}
		}
	})
	t.Run("transient 429 resource exhausted is not marked", func(t *testing.T) {
		payload := []byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`)
		if shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, payload, nil) {
			t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = true, want false", string(payload))
		}
	})
	t.Run("resource exhausted with quota metadata is still marked", func(t *testing.T) {
		payload := []byte(`{"error":{"code":429,"message":"Resource has been exhausted","status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"1h","model":"claude-sonnet-4-6"}}]}}`)
		if !shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, payload, nil) {
			t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(payload))
		}
	})
	// Server-side failures never count as credit exhaustion.
	if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) {
		t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false")
	}
}
// TestAntigravityExecute_RetriesTransient429ResourceExhausted verifies that a
// 429 whose body is a generic RESOURCE_EXHAUSTED (neither quota-exhausted nor
// rate-limited) is retried and that the follow-up request succeeds.
//
// The httptest handler runs on the server's goroutines, so the request
// counter is guarded by a mutex and unexpected requests are reported with
// t.Errorf — t.Fatalf (FailNow) must only be called from the test goroutine.
func TestAntigravityExecute_RetriesTransient429ResourceExhausted(t *testing.T) {
	resetAntigravityCreditsRetryState()
	t.Cleanup(resetAntigravityCreditsRetryState)

	var (
		mu           sync.Mutex
		requestCount int
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestCount++
		reqNum := requestCount
		mu.Unlock()
		switch reqNum {
		case 1:
			// First attempt: transient RESOURCE_EXHAUSTED 429.
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`))
		case 2:
			// Retry: succeed with a minimal valid response.
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
		default:
			t.Errorf("unexpected request count %d", reqNum)
			w.WriteHeader(http.StatusInternalServerError)
		}
	}))
	defer server.Close()

	exec := NewAntigravityExecutor(&config.Config{RequestRetry: 1})
	auth := &cliproxyauth.Auth{
		ID: "auth-transient-429",
		Attributes: map[string]string{
			"base_url": server.URL,
		},
		Metadata: map[string]any{
			"access_token": "token",
			"project_id":   "project-1",
			"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
		},
	}

	resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gemini-2.5-flash",
		Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FormatAntigravity,
	})
	if err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if len(resp.Payload) == 0 {
		t.Fatal("Execute() returned empty payload")
	}

	mu.Lock()
	finalCount := requestCount
	mu.Unlock()
	if finalCount != 2 {
		t.Fatalf("request count = %d, want 2", finalCount)
	}
}
// TestAntigravityExecute_RetriesQuotaExhaustedWithCredits verifies that after
// a QUOTA_EXHAUSTED 429, the executor retries the request with
// enabledCreditTypes injected when the antigravity-credits fallback is on.
//
// Failures detected inside the handler are reported with t.Errorf because the
// handler runs on a server goroutine; t.Fatalf (FailNow) is only valid on the
// test goroutine. The handler still answers so the client never hangs.
func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) {
	resetAntigravityCreditsRetryState()
	t.Cleanup(resetAntigravityCreditsRetryState)

	var (
		mu            sync.Mutex
		requestBodies []string
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		_ = r.Body.Close()
		mu.Lock()
		requestBodies = append(requestBodies, string(body))
		reqNum := len(requestBodies)
		mu.Unlock()
		if reqNum == 1 {
			// Trigger the credits fallback path.
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
			return
		}
		if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
			t.Errorf("second request body missing enabledCreditTypes: %s", string(body))
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
	}))
	defer server.Close()

	exec := NewAntigravityExecutor(&config.Config{
		QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
	})
	auth := &cliproxyauth.Auth{
		ID: "auth-credits-ok",
		Attributes: map[string]string{
			"base_url": server.URL,
		},
		Metadata: map[string]any{
			"access_token": "token",
			"project_id":   "project-1",
			"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
		},
	}

	resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gemini-2.5-flash",
		Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FormatAntigravity,
	})
	if err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if len(resp.Payload) == 0 {
		t.Fatal("Execute() returned empty payload")
	}

	mu.Lock()
	defer mu.Unlock()
	if len(requestBodies) != 2 {
		t.Fatalf("request count = %d, want 2", len(requestBodies))
	}
}
// TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted verifies that
// once an auth is marked credits-exhausted, a QUOTA_EXHAUSTED 429 is returned
// directly instead of triggering another credits retry.
//
// The request counter is written from the server's handler goroutine and read
// from the test goroutine, so it is guarded by a mutex (matching the sibling
// tests in this file).
func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) {
	resetAntigravityCreditsRetryState()
	t.Cleanup(resetAntigravityCreditsRetryState)

	var (
		mu           sync.Mutex
		requestCount int
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestCount++
		mu.Unlock()
		w.WriteHeader(http.StatusTooManyRequests)
		_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
	}))
	defer server.Close()

	exec := NewAntigravityExecutor(&config.Config{
		QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
	})
	auth := &cliproxyauth.Auth{
		ID: "auth-credits-exhausted",
		Attributes: map[string]string{
			"base_url": server.URL,
		},
		Metadata: map[string]any{
			"access_token": "token",
			"project_id":   "project-1",
			"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
		},
	}
	// Pre-mark the auth so the executor must skip the credits retry.
	markAntigravityCreditsExhausted(auth, time.Now())

	_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gemini-2.5-flash",
		Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FormatAntigravity,
	})
	if err == nil {
		t.Fatal("Execute() error = nil, want 429")
	}
	sErr, ok := err.(statusErr)
	if !ok {
		t.Fatalf("Execute() error type = %T, want statusErr", err)
	}
	if got := sErr.StatusCode(); got != http.StatusTooManyRequests {
		t.Fatalf("Execute() status code = %d, want %d", got, http.StatusTooManyRequests)
	}

	mu.Lock()
	finalCount := requestCount
	mu.Unlock()
	if finalCount != 1 {
		t.Fatalf("request count = %d, want 1", finalCount)
	}
}
// TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback verifies that
// after one successful credits fallback, subsequent requests for the same
// model go straight to the credits-enabled payload (request 3 carries
// enabledCreditTypes without a preceding 429).
//
// Handler-side failures are reported with t.Errorf: the handler runs on a
// server goroutine, where t.Fatalf (FailNow) is not permitted. The handler
// always writes a response so the client cannot hang.
func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) {
	resetAntigravityCreditsRetryState()
	t.Cleanup(resetAntigravityCreditsRetryState)

	var (
		mu            sync.Mutex
		requestBodies []string
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		_ = r.Body.Close()
		mu.Lock()
		requestBodies = append(requestBodies, string(body))
		reqNum := len(requestBodies)
		mu.Unlock()
		switch reqNum {
		case 1:
			// Structured quota exhaustion with a retry delay.
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`))
		case 2, 3:
			if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
				t.Errorf("request %d body missing enabledCreditTypes: %s", reqNum, string(body))
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
		default:
			t.Errorf("unexpected request count %d", reqNum)
			w.WriteHeader(http.StatusInternalServerError)
		}
	}))
	defer server.Close()

	exec := NewAntigravityExecutor(&config.Config{
		QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
	})
	auth := &cliproxyauth.Auth{
		ID: "auth-prefer-credits",
		Attributes: map[string]string{
			"base_url": server.URL,
		},
		Metadata: map[string]any{
			"access_token": "token",
			"project_id":   "project-1",
			"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
		},
	}
	request := cliproxyexecutor.Request{
		Model:   "gemini-2.5-flash",
		Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
	}
	opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity}

	if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
		t.Fatalf("first Execute() error = %v", err)
	}
	if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
		t.Fatalf("second Execute() error = %v", err)
	}

	mu.Lock()
	defer mu.Unlock()
	if len(requestBodies) != 3 {
		t.Fatalf("request count = %d, want 3", len(requestBodies))
	}
	if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
		t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0])
	}
	if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
		t.Fatalf("fallback request missing credits: %s", requestBodies[1])
	}
	if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
		t.Fatalf("preferred request missing credits: %s", requestBodies[2])
	}
}
// TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure
// verifies that when the credits retry on the primary base URL fails (403),
// the executor still falls back to the next base URL and succeeds there.
//
// Two fixes versus the original: handler-side failures use t.Errorf (t.Fatalf
// must not be called off the test goroutine), and the final counter
// assertions take the mutex that the handlers use to write the counters.
func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) {
	resetAntigravityCreditsRetryState()
	t.Cleanup(resetAntigravityCreditsRetryState)

	var (
		mu          sync.Mutex
		firstCount  int
		secondCount int
	)
	firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		_ = r.Body.Close()
		mu.Lock()
		firstCount++
		reqNum := firstCount
		mu.Unlock()
		switch reqNum {
		case 1:
			// Structured quota exhaustion: triggers the credits retry.
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`))
		case 2:
			if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
				t.Errorf("credits retry missing enabledCreditTypes: %s", string(body))
			}
			// Fail the credits retry so the executor moves to the next base URL.
			w.WriteHeader(http.StatusForbidden)
			_, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`))
		default:
			t.Errorf("unexpected first server request count %d", reqNum)
			w.WriteHeader(http.StatusInternalServerError)
		}
	}))
	defer firstServer.Close()

	secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		secondCount++
		mu.Unlock()
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
	}))
	defer secondServer.Close()

	exec := NewAntigravityExecutor(&config.Config{
		QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
	})
	auth := &cliproxyauth.Auth{
		ID: "auth-baseurl-fallback",
		Attributes: map[string]string{
			"base_url": firstServer.URL,
		},
		Metadata: map[string]any{
			"access_token": "token",
			"project_id":   "project-1",
			"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
		},
	}

	// Stub the fallback order so both test servers participate.
	originalOrder := antigravityBaseURLFallbackOrder
	defer func() { antigravityBaseURLFallbackOrder = originalOrder }()
	antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
		return []string{firstServer.URL, secondServer.URL}
	}

	resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gemini-2.5-flash",
		Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FormatAntigravity,
	})
	if err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if len(resp.Payload) == 0 {
		t.Fatal("Execute() returned empty payload")
	}

	mu.Lock()
	gotFirst, gotSecond := firstCount, secondCount
	mu.Unlock()
	if gotFirst != 2 {
		t.Fatalf("first server request count = %d, want 2", gotFirst)
	}
	if gotSecond != 1 {
		t.Fatalf("second server request count = %d, want 1", gotSecond)
	}
}
// TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled verifies
// that with the antigravity-credits flag off, a prefer-credits mark is
// ignored: the request goes out without enabledCreditTypes and the 429 is
// returned without a credits retry.
//
// requestBodies is appended on the server's handler goroutine and read from
// the test goroutine, so access is guarded by a mutex (matching the sibling
// tests in this file).
func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) {
	resetAntigravityCreditsRetryState()
	t.Cleanup(resetAntigravityCreditsRetryState)

	var (
		mu            sync.Mutex
		requestBodies []string
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		_ = r.Body.Close()
		mu.Lock()
		requestBodies = append(requestBodies, string(body))
		mu.Unlock()
		w.WriteHeader(http.StatusTooManyRequests)
		_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
	}))
	defer server.Close()

	exec := NewAntigravityExecutor(&config.Config{
		QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false},
	})
	auth := &cliproxyauth.Auth{
		ID: "auth-flag-disabled",
		Attributes: map[string]string{
			"base_url": server.URL,
		},
		Metadata: map[string]any{
			"access_token": "token",
			"project_id":   "project-1",
			"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
		},
	}
	// Even a prior prefer-credits mark must be ignored while the flag is off.
	markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil)

	_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gemini-2.5-flash",
		Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FormatAntigravity,
	})
	if err == nil {
		t.Fatal("Execute() error = nil, want 429")
	}

	mu.Lock()
	defer mu.Unlock()
	if len(requestBodies) != 1 {
		t.Fatalf("request count = %d, want 1", len(requestBodies))
	}
	if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
		t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0])
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,9 +4,11 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"context" "context"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@@ -14,7 +16,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
xxHash64 "github.com/pierrec/xxHash/xxHash64"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -23,9 +28,7 @@ import (
) )
func resetClaudeDeviceProfileCache() { func resetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock() helps.ResetClaudeDeviceProfileCache()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
} }
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request { func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
@@ -98,7 +101,7 @@ func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
req := newClaudeHeaderTestRequest(t, incoming) req := newClaudeHeaderTestRequest(t, incoming)
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg) applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64") assertClaudeFingerprint(t, req.Header, "evil-client/9.9", "9.9.9", "v24.5.0", "Linux", "x64")
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" { if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900") t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
} }
@@ -338,7 +341,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
var pauseOnce sync.Once var pauseOnce sync.Once
var releaseOnce sync.Once var releaseOnce sync.Once
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) { helps.ClaudeDeviceProfileBeforeCandidateStore = func(candidate helps.ClaudeDeviceProfile) {
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" { if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
return return
} }
@@ -346,13 +349,13 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
<-releaseLow <-releaseLow
} }
t.Cleanup(func() { t.Cleanup(func() {
claudeDeviceProfileBeforeCandidateStore = nil helps.ClaudeDeviceProfileBeforeCandidateStore = nil
releaseOnce.Do(func() { close(releaseLow) }) releaseOnce.Do(func() { close(releaseLow) })
}) })
lowResultCh := make(chan claudeDeviceProfile, 1) lowResultCh := make(chan helps.ClaudeDeviceProfile, 1)
go func() { go func() {
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{ lowResultCh <- helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"}, "User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"}, "X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"}, "X-Stainless-Runtime-Version": []string{"v24.3.0"},
@@ -367,7 +370,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
t.Fatal("timed out waiting for lower candidate to pause before storing") t.Fatal("timed out waiting for lower candidate to pause before storing")
} }
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{ highResult := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"}, "User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"}, "X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"}, "X-Stainless-Runtime-Version": []string{"v24.4.0"},
@@ -398,7 +401,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64") t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
} }
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{ cached := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"curl/8.7.1"}, "User-Agent": []string{"curl/8.7.1"},
}, cfg) }, cfg)
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" { if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
@@ -564,7 +567,7 @@ func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *tes
}) })
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg) applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch()) assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
} }
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) { func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
@@ -591,14 +594,14 @@ func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallbac
}) })
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg) applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch()) assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
} }
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) { func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
if claudeDeviceProfileStabilizationEnabled(nil) { if helps.ClaudeDeviceProfileStabilizationEnabled(nil) {
t.Fatal("expected nil config to default to disabled stabilization") t.Fatal("expected nil config to default to disabled stabilization")
} }
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) { if helps.ClaudeDeviceProfileStabilizationEnabled(&config.Config{}) {
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization") t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
} }
} }
@@ -736,6 +739,35 @@ func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
} }
} }
func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) {
for _, builtin := range []string{"web_search", "code_execution", "text_editor", "computer"} {
t.Run(builtin, func(t *testing.T) {
input := []byte(fmt.Sprintf(`{
"tools":[{"name":"Read"}],
"tool_choice":{"type":"tool","name":%q},
"messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}]
}`, builtin, builtin, builtin, builtin))
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin {
t.Fatalf("tool_choice.name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin {
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin {
t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
})
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) { func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`) input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_") out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -796,8 +828,6 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
} }
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) { func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
resetUserIDCache()
var userIDs []string var userIDs []string
var requestModels []string var requestModels []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -857,15 +887,13 @@ func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
if userIDs[0] != userIDs[1] { if userIDs[0] != userIDs[1] {
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1]) t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
} }
if !isValidUserID(userIDs[0]) { if !helps.IsValidUserID(userIDs[0]) {
t.Fatalf("user_id %q is not valid", userIDs[0]) t.Fatalf("user_id %q is not valid", userIDs[0])
} }
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0]) t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
} }
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) { func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
resetUserIDCache()
var userIDs []string var userIDs []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body) body, _ := io.ReadAll(r.Body)
@@ -903,7 +931,7 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
if userIDs[0] == userIDs[1] { if userIDs[0] == userIDs[1] {
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0]) t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
} }
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) { if !helps.IsValidUserID(userIDs[0]) || !helps.IsValidUserID(userIDs[1]) {
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1]) t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
} }
} }
@@ -966,6 +994,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.
} }
} }
func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) {
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
out := normalizeCacheControlTTL(payload)
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
}
outStr := string(out)
idxModel := strings.Index(outStr, `"model"`)
idxMessages := strings.Index(outStr, `"messages"`)
idxTools := strings.Index(outStr, `"tools"`)
idxSystem := strings.Index(outStr, `"system"`)
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
}
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{ payload := []byte(`{
"tools": [ "tools": [
@@ -995,6 +1045,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T)
} }
} }
func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) {
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
out := enforceCacheControlLimit(payload, 4)
if got := countCacheControls(out); got != 4 {
t.Fatalf("cache_control count = %d, want 4", got)
}
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
}
outStr := string(out)
idxModel := strings.Index(outStr, `"model"`)
idxMessages := strings.Index(outStr, `"messages"`)
idxTools := strings.Index(outStr, `"tools"`)
idxSystem := strings.Index(outStr, `"system"`)
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
}
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) { func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
payload := []byte(`{ payload := []byte(`{
"tools": [ "tools": [
@@ -1183,6 +1258,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
} }
} }
func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-max-completion-tokens-client"
modelID := "test-claude-max-completion-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
MaxCompletionTokens: 4096,
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
t.Fatalf("max_tokens = %d, want %d", got, 4096)
}
}
func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-default-max-tokens-client"
modelID := "test-claude-default-max-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
}
}
func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-preserve-max-tokens-client"
modelID := "test-claude-preserve-max-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
MaxCompletionTokens: 4096,
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
t.Fatalf("max_tokens = %d, want %d", got, 2048)
}
}
func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, "test-claude-unregistered-model")
if gjson.GetBytes(out, "max_tokens").Exists() {
t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
}
}
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming // TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
// requests use Accept-Encoding: identity so the upstream cannot respond with a // requests use Accept-Encoding: identity so the upstream cannot respond with a
// compressed SSE body that would silently break the line scanner. // compressed SSE body that would silently break the line scanner.
@@ -1340,6 +1492,35 @@ func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
} }
} }
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes even when Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns // TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
// plain text untouched when Content-Encoding is absent and no magic bytes match. // plain text untouched when Content-Encoding is absent and no magic bytes match.
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) { func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
@@ -1411,77 +1592,6 @@ func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T)
} }
} }
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
// path's enforced identity encoding.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
// Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the // TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
// error path (4xx) correctly decompresses a gzip body even when the upstream omits // error path (4xx) correctly decompresses a gzip body even when the upstream omits
// the Content-Encoding header. This closes the gap left by PR #1771, which only // the Content-Encoding header. This closes the gap left by PR #1771, which only
@@ -1565,6 +1675,45 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
} }
} }
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies that the
// streaming executor enforces Accept-Encoding: identity regardless of auth.Attributes override.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// Test case 1: String system prompt is preserved and converted to a content block // Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) { func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`) payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
@@ -1648,3 +1797,155 @@ func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String()) t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
} }
} }
func TestClaudeExecutor_ExperimentalCCHSigningDisabledByDefaultKeepsLegacyHeader(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected request body to be captured")
}
billingHeader := gjson.GetBytes(seenBody, "system.0.text").String()
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
t.Fatalf("system.0.text = %q, want billing header", billingHeader)
}
if strings.Contains(billingHeader, "cch=00000;") {
t.Fatalf("legacy mode should not forward cch placeholder, got %q", billingHeader)
}
}
func TestClaudeExecutor_ExperimentalCCHSigningOptInSignsFinalBody(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{
ClaudeKey: []config.ClaudeKey{{
APIKey: "key-123",
BaseURL: server.URL,
ExperimentalCCHSigning: true,
}},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
const messageText = "please keep literal cch=00000 in this message"
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"please keep literal cch=00000 in this message"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected request body to be captured")
}
if got := gjson.GetBytes(seenBody, "messages.0.content.0.text").String(); got != messageText {
t.Fatalf("message text = %q, want %q", got, messageText)
}
billingPattern := regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)([0-9a-f]{5})(;)`)
match := billingPattern.FindSubmatch(seenBody)
if match == nil {
t.Fatalf("expected signed billing header in body: %s", string(seenBody))
}
actualCCH := string(match[2])
unsignedBody := billingPattern.ReplaceAll(seenBody, []byte(`${1}00000${3}`))
wantCCH := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, 0x6E52736AC806831E)&0xFFFFF)
if actualCCH != wantCCH {
t.Fatalf("cch = %q, want %q\nbody: %s", actualCCH, wantCCH, string(seenBody))
}
}
func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmitted(t *testing.T) {
cfg := &config.Config{
ClaudeKey: []config.ClaudeKey{{
APIKey: "key-123",
Cloak: &config.CloakConfig{
StrictMode: true,
SensitiveWords: []string{"proxy"},
},
}},
}
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "key-123"}}
payload := []byte(`{"system":"proxy rules","messages":[{"role":"user","content":[{"type":"text","text":"proxy access"}]}]}`)
out := applyCloaking(context.Background(), cfg, auth, payload, "claude-3-5-sonnet-20241022", "key-123")
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("expected strict mode to keep only injected system blocks, got %d", len(blocks))
}
if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); !strings.Contains(got, "\u200B") {
t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got)
}
}
func TestNormalizeClaudeTemperatureForThinking_AdaptiveCoercesToOne(t *testing.T) {
payload := []byte(`{"temperature":0,"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`)
out := normalizeClaudeTemperatureForThinking(payload)
if got := gjson.GetBytes(out, "temperature").Float(); got != 1 {
t.Fatalf("temperature = %v, want 1", got)
}
}
func TestNormalizeClaudeTemperatureForThinking_EnabledCoercesToOne(t *testing.T) {
payload := []byte(`{"temperature":0.2,"thinking":{"type":"enabled","budget_tokens":2048}}`)
out := normalizeClaudeTemperatureForThinking(payload)
if got := gjson.GetBytes(out, "temperature").Float(); got != 1 {
t.Fatalf("temperature = %v, want 1", got)
}
}
func TestNormalizeClaudeTemperatureForThinking_NoThinkingLeavesTemperatureAlone(t *testing.T) {
payload := []byte(`{"temperature":0,"messages":[{"role":"user","content":"hi"}]}`)
out := normalizeClaudeTemperatureForThinking(payload)
if got := gjson.GetBytes(out, "temperature").Float(); got != 0 {
t.Fatalf("temperature = %v, want 0", got)
}
}
func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOriginalTemperature(t *testing.T) {
payload := []byte(`{"temperature":0,"thinking":{"type":"adaptive"},"output_config":{"effort":"max"},"tool_choice":{"type":"any"}}`)
out := disableThinkingIfToolChoiceForced(payload)
out = normalizeClaudeTemperatureForThinking(out)
if gjson.GetBytes(out, "thinking").Exists() {
t.Fatalf("thinking should be removed when tool_choice forces tool use")
}
if got := gjson.GetBytes(out, "temperature").Float(); got != 0 {
t.Fatalf("temperature = %v, want 0", got)
}
}

View File

@@ -0,0 +1,81 @@
package executor
import (
"fmt"
"regexp"
"strings"
xxHash64 "github.com/pierrec/xxHash/xxHash64"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const claudeCCHSeed uint64 = 0x6E52736AC806831E
var claudeBillingHeaderCCHPattern = regexp.MustCompile(`\bcch=([0-9a-f]{5});`)
func signAnthropicMessagesBody(body []byte) []byte {
billingHeader := gjson.GetBytes(body, "system.0.text").String()
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
return body
}
if !claudeBillingHeaderCCHPattern.MatchString(billingHeader) {
return body
}
unsignedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(billingHeader, "cch=00000;")
unsignedBody, err := sjson.SetBytes(body, "system.0.text", unsignedBillingHeader)
if err != nil {
return body
}
cch := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, claudeCCHSeed)&0xFFFFF)
signedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(unsignedBillingHeader, "cch="+cch+";")
signedBody, err := sjson.SetBytes(unsignedBody, "system.0.text", signedBillingHeader)
if err != nil {
return unsignedBody
}
return signedBody
}
func resolveClaudeKeyConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.ClaudeKey {
if cfg == nil || auth == nil {
return nil
}
apiKey, baseURL := claudeCreds(auth)
if apiKey == "" {
return nil
}
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if !strings.EqualFold(cfgKey, apiKey) {
continue
}
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
continue
}
return entry
}
return nil
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
entry := resolveClaudeKeyConfig(cfg, auth)
if entry == nil {
return nil
}
return entry.Cloak
}
func experimentalCCHSigningEnabled(cfg *config.Config, auth *cliproxyauth.Auth) bool {
entry := resolveClaudeKeyConfig(cfg, auth)
return entry != nil && entry.ExperimentalCCHSigning
}

View File

@@ -4,9 +4,11 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
@@ -14,8 +16,11 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
const ( const (
@@ -98,10 +103,12 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
if len(opts.OriginalRequest) > 0 { if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest originalPayloadSource = opts.OriginalRequest
} }
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false) originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, _ = sjson.SetBytes(translated, "stream", true)
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil { if err != nil {
@@ -114,6 +121,8 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err return resp, err
} }
e.applyHeaders(httpReq, accessToken, userID, domain) e.applyHeaders(httpReq, accessToken, userID, domain)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -160,11 +169,16 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, body) appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body)) aggregatedBody, usageDetail, err := aggregateOpenAIChatCompletionStream(body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
reporter.publish(ctx, usageDetail)
reporter.ensurePublished(ctx) reporter.ensurePublished(ctx)
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, aggregatedBody, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -341,3 +355,197 @@ func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID,
req.Header.Set("X-IDE-Version", "2.63.2") req.Header.Set("X-IDE-Version", "2.63.2")
req.Header.Set("X-Requested-With", "XMLHttpRequest") req.Header.Set("X-Requested-With", "XMLHttpRequest")
} }
type openAIChatStreamChoiceAccumulator struct {
Role string
ContentParts []string
ReasoningParts []string
FinishReason string
ToolCalls map[int]*openAIChatStreamToolCallAccumulator
ToolCallOrder []int
NativeFinishReason any
}
type openAIChatStreamToolCallAccumulator struct {
ID string
Type string
Name string
Arguments strings.Builder
}
func aggregateOpenAIChatCompletionStream(raw []byte) ([]byte, usage.Detail, error) {
lines := bytes.Split(raw, []byte("\n"))
var (
responseID string
model string
created int64
serviceTier string
systemFP string
usageDetail usage.Detail
choices = map[int]*openAIChatStreamChoiceAccumulator{}
choiceOrder []int
)
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(line[5:])
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
if !gjson.ValidBytes(payload) {
continue
}
root := gjson.ParseBytes(payload)
if responseID == "" {
responseID = root.Get("id").String()
}
if model == "" {
model = root.Get("model").String()
}
if created == 0 {
created = root.Get("created").Int()
}
if serviceTier == "" {
serviceTier = root.Get("service_tier").String()
}
if systemFP == "" {
systemFP = root.Get("system_fingerprint").String()
}
if detail, ok := parseOpenAIStreamUsage(line); ok {
usageDetail = detail
}
for _, choiceResult := range root.Get("choices").Array() {
idx := int(choiceResult.Get("index").Int())
choice := choices[idx]
if choice == nil {
choice = &openAIChatStreamChoiceAccumulator{ToolCalls: map[int]*openAIChatStreamToolCallAccumulator{}}
choices[idx] = choice
choiceOrder = append(choiceOrder, idx)
}
delta := choiceResult.Get("delta")
if role := delta.Get("role").String(); role != "" {
choice.Role = role
}
if content := delta.Get("content").String(); content != "" {
choice.ContentParts = append(choice.ContentParts, content)
}
if reasoning := delta.Get("reasoning_content").String(); reasoning != "" {
choice.ReasoningParts = append(choice.ReasoningParts, reasoning)
}
if finishReason := choiceResult.Get("finish_reason").String(); finishReason != "" {
choice.FinishReason = finishReason
}
if nativeFinishReason := choiceResult.Get("native_finish_reason"); nativeFinishReason.Exists() {
choice.NativeFinishReason = nativeFinishReason.Value()
}
for _, toolCallResult := range delta.Get("tool_calls").Array() {
toolIdx := int(toolCallResult.Get("index").Int())
toolCall := choice.ToolCalls[toolIdx]
if toolCall == nil {
toolCall = &openAIChatStreamToolCallAccumulator{}
choice.ToolCalls[toolIdx] = toolCall
choice.ToolCallOrder = append(choice.ToolCallOrder, toolIdx)
}
if id := toolCallResult.Get("id").String(); id != "" {
toolCall.ID = id
}
if typ := toolCallResult.Get("type").String(); typ != "" {
toolCall.Type = typ
}
if name := toolCallResult.Get("function.name").String(); name != "" {
toolCall.Name = name
}
if args := toolCallResult.Get("function.arguments").String(); args != "" {
toolCall.Arguments.WriteString(args)
}
}
}
}
if responseID == "" && model == "" && len(choiceOrder) == 0 {
return nil, usageDetail, fmt.Errorf("codebuddy: streaming response did not contain any chat completion chunks")
}
response := map[string]any{
"id": responseID,
"object": "chat.completion",
"created": created,
"model": model,
"choices": make([]map[string]any, 0, len(choiceOrder)),
"usage": map[string]any{
"prompt_tokens": usageDetail.InputTokens,
"completion_tokens": usageDetail.OutputTokens,
"total_tokens": usageDetail.TotalTokens,
},
}
if serviceTier != "" {
response["service_tier"] = serviceTier
}
if systemFP != "" {
response["system_fingerprint"] = systemFP
}
for _, idx := range choiceOrder {
choice := choices[idx]
message := map[string]any{
"role": choice.Role,
"content": strings.Join(choice.ContentParts, ""),
}
if message["role"] == "" {
message["role"] = "assistant"
}
if len(choice.ReasoningParts) > 0 {
message["reasoning_content"] = strings.Join(choice.ReasoningParts, "")
}
if len(choice.ToolCallOrder) > 0 {
toolCalls := make([]map[string]any, 0, len(choice.ToolCallOrder))
for _, toolIdx := range choice.ToolCallOrder {
toolCall := choice.ToolCalls[toolIdx]
toolCallType := toolCall.Type
if toolCallType == "" {
toolCallType = "function"
}
arguments := toolCall.Arguments.String()
if arguments == "" {
arguments = "{}"
}
toolCalls = append(toolCalls, map[string]any{
"id": toolCall.ID,
"type": toolCallType,
"function": map[string]any{
"name": toolCall.Name,
"arguments": arguments,
},
})
}
message["tool_calls"] = toolCalls
}
finishReason := choice.FinishReason
if finishReason == "" {
finishReason = "stop"
}
choicePayload := map[string]any{
"index": idx,
"message": message,
"finish_reason": finishReason,
}
if choice.NativeFinishReason != nil {
choicePayload["native_finish_reason"] = choice.NativeFinishReason
}
response["choices"] = append(response["choices"].([]map[string]any), choicePayload)
}
out, err := json.Marshal(response)
if err != nil {
return nil, usageDetail, fmt.Errorf("codebuddy: failed to encode aggregated response: %w", err)
}
return out, usageDetail, nil
}

View File

@@ -7,12 +7,14 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"sort"
"strings" "strings"
"time" "time"
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -28,8 +30,8 @@ import (
) )
const ( const (
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464" codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
codexOriginator = "codex_cli_rs" codexOriginator = "codex-tui"
) )
var dataTag = []byte("data:") var dataTag = []byte("data:")
@@ -73,7 +75,7 @@ func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
@@ -88,8 +90,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = "https://chatgpt.com/backend-api/codex" baseURL = "https://chatgpt.com/backend-api/codex"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
@@ -106,16 +108,15 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier") body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() { body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "instructions", "") body = normalizeCodexInstructions(body)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses" url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body) httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -129,7 +130,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -140,10 +141,10 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -151,38 +152,79 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
log.Errorf("codex executor: close response body error: %v", errClose) log.Errorf("codex executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b) err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err return resp, err
} }
data, err := io.ReadAll(httpResp.Body) data, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lines := bytes.Split(data, []byte("\n")) lines := bytes.Split(data, []byte("\n"))
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range lines { for _, line := range lines {
if !bytes.HasPrefix(line, dataTag) { if !bytes.HasPrefix(line, dataTag) {
continue continue
} }
line = bytes.TrimSpace(line[5:]) eventData := bytes.TrimSpace(line[5:])
if gjson.GetBytes(line, "type").String() != "response.completed" { eventType := gjson.GetBytes(eventData, "type").String()
if eventType == "response.output_item.done" {
itemResult := gjson.GetBytes(eventData, "item")
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
continue
}
outputIndexResult := gjson.GetBytes(eventData, "output_index")
if outputIndexResult.Exists() {
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
} else {
outputItemsFallback = append(outputItemsFallback, []byte(itemResult.Raw))
}
continue continue
} }
if detail, ok := parseCodexUsage(line); ok { if eventType != "response.completed" {
reporter.publish(ctx, detail) continue
}
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
completedData := eventData
outputResult := gjson.GetBytes(completedData, "response.output")
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
if shouldPatchOutput {
completedDataPatched := completedData
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output", []byte(`[]`))
indexes := make([]int64, 0, len(outputItemsByIndex))
for idx := range outputItemsByIndex {
indexes = append(indexes, idx)
}
sort.Slice(indexes, func(i, j int) bool {
return indexes[i] < indexes[j]
})
for _, idx := range indexes {
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", outputItemsByIndex[idx])
}
for _, item := range outputItemsFallback {
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", item)
}
completedData = completedDataPatched
} }
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, completedData, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -198,8 +240,8 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
baseURL = "https://chatgpt.com/backend-api/codex" baseURL = "https://chatgpt.com/backend-api/codex"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai-response") to := sdktranslator.FromString("openai-response")
@@ -216,10 +258,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream") body, _ = sjson.DeleteBytes(body, "stream")
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact" url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body) httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -233,7 +276,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -244,10 +287,10 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -255,22 +298,22 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
log.Errorf("codex executor: close response body error: %v", errClose) log.Errorf("codex executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b) err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err return resp, err
} }
data, err := io.ReadAll(httpResp.Body) data, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data)) reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -288,8 +331,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = "https://chatgpt.com/backend-api/codex" baseURL = "https://chatgpt.com/backend-api/codex"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
@@ -306,15 +349,14 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err return nil, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier") body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
if !gjson.GetBytes(body, "instructions").Exists() { body = normalizeCodexInstructions(body)
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses" url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body) httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -328,7 +370,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -340,24 +382,24 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, readErr := io.ReadAll(httpResp.Body) data, readErr := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose) log.Errorf("codex executor: close response body error: %v", errClose)
} }
if readErr != nil { if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr) helps.RecordAPIResponseError(ctx, e.cfg, readErr)
return nil, readErr return nil, readErr
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data) err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err return nil, err
} }
@@ -374,13 +416,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if bytes.HasPrefix(line, dataTag) { if bytes.HasPrefix(line, dataTag) {
data := bytes.TrimSpace(line[5:]) data := bytes.TrimSpace(line[5:])
if gjson.GetBytes(data, "type").String() == "response.completed" { if gjson.GetBytes(data, "type").String() == "response.completed" {
if detail, ok := parseCodexUsage(data); ok { if detail, ok := helps.ParseCodexUsage(data); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
} }
} }
@@ -391,8 +433,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
}() }()
@@ -415,10 +457,9 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier") body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "stream", false) body, _ = sjson.SetBytes(body, "stream", false)
if !gjson.GetBytes(body, "instructions").Exists() { body = normalizeCodexInstructions(body)
body, _ = sjson.SetBytes(body, "instructions", "")
}
enc, err := tokenizerForCodexModel(baseModel) enc, err := tokenizerForCodexModel(baseModel)
if err != nil { if err != nil {
@@ -597,18 +638,18 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
} }
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) { func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
var cache codexCache var cache helps.CodexCache
if from == "claude" { if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id") userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() { if userIDResult.Exists() {
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
var ok bool var ok bool
if cache, ok = getCodexCache(key); !ok { if cache, ok = helps.GetCodexCache(key); !ok {
cache = codexCache{ cache = helps.CodexCache{
ID: uuid.New().String(), ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour), Expire: time.Now().Add(1 * time.Hour),
} }
setCodexCache(key, cache) helps.SetCodexCache(key, cache)
} }
} }
} else if from == "openai-response" { } else if from == "openai-response" {
@@ -617,7 +658,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
cache.ID = promptCacheKey.String() cache.ID = promptCacheKey.String()
} }
} else if from == "openai" { } else if from == "openai" {
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" { if apiKey := strings.TrimSpace(helps.APIKeyFromContext(ctx)); apiKey != "" {
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String() cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
} }
} }
@@ -630,7 +671,6 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
return nil, err return nil, err
} }
if cache.ID != "" { if cache.ID != "" {
httpReq.Header.Set("Conversation_id", cache.ID)
httpReq.Header.Set("Session_id", cache.ID) httpReq.Header.Set("Session_id", cache.ID)
} }
return httpReq, nil return httpReq, nil
@@ -645,13 +685,19 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
ginHeaders = ginCtx.Request.Header ginHeaders = ginCtx.Request.Header
} }
if ginHeaders.Get("X-Codex-Beta-Features") != "" {
r.Header.Set("X-Codex-Beta-Features", ginHeaders.Get("X-Codex-Beta-Features"))
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", "") misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "") misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "") misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth) cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent) ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if strings.Contains(r.Header.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
}
if stream { if stream {
r.Header.Set("Accept", "text/event-stream") r.Header.Set("Accept", "text/event-stream")
} else { } else {
@@ -685,13 +731,47 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
} }
func newCodexStatusErr(statusCode int, body []byte) statusErr { func newCodexStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)} errCode := statusCode
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil { if isCodexModelCapacityError(body) {
errCode = http.StatusTooManyRequests
}
err := statusErr{code: errCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
err.retryAfter = retryAfter err.retryAfter = retryAfter
} }
return err return err
} }
// normalizeCodexInstructions guarantees the outgoing payload carries a string
// "instructions" field: a missing field or an explicit JSON null is replaced
// with the empty string, anything else is left untouched.
func normalizeCodexInstructions(body []byte) []byte {
	field := gjson.GetBytes(body, "instructions")
	if field.Exists() && field.Type != gjson.Null {
		return body
	}
	normalized, _ := sjson.SetBytes(body, "instructions", "")
	return normalized
}
// isCodexModelCapacityError reports whether the upstream error body carries a
// "model is at capacity" message. It inspects the structured error fields
// first and falls back to the raw body text; matching is case-insensitive.
func isCodexModelCapacityError(errorBody []byte) bool {
	if len(errorBody) == 0 {
		return false
	}
	patterns := []string{
		"selected model is at capacity",
		"model is at capacity. please try a different model",
	}
	for _, candidate := range []string{
		gjson.GetBytes(errorBody, "error.message").String(),
		gjson.GetBytes(errorBody, "message").String(),
		string(errorBody),
	} {
		lower := strings.ToLower(strings.TrimSpace(candidate))
		if lower == "" {
			continue
		}
		for _, pattern := range patterns {
			if strings.Contains(lower, pattern) {
				return true
			}
		}
	}
	return false
}
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration { func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 { if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
return nil return nil

View File

@@ -42,8 +42,8 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
if gotKey != expectedKey { if gotKey != expectedKey {
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey) t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
} }
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey { if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != "" {
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey) t.Fatalf("Conversation_id = %q, want empty", gotConversation)
} }
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey { if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey) t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)

View File

@@ -0,0 +1,79 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestCodexExecutorCompactAddsDefaultInstructions verifies that the compact
// code path (Alt "responses/compact") always sends a string "instructions"
// field upstream: payloads that omit the field or set it to JSON null must
// arrive at the server with instructions normalized to the empty string.
func TestCodexExecutorCompactAddsDefaultInstructions(t *testing.T) {
	cases := []struct {
		name    string
		payload string
	}{
		{
			name:    "missing instructions",
			payload: `{"model":"gpt-5.4","input":"hello"}`,
		},
		{
			name:    "null instructions",
			payload: `{"model":"gpt-5.4","instructions":null,"input":"hello"}`,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Capture the path and body the executor actually sends upstream.
			var gotPath string
			var gotBody []byte
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				gotPath = r.URL.Path
				body, _ := io.ReadAll(r.Body)
				gotBody = body
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`))
			}))
			defer server.Close()
			executor := NewCodexExecutor(&config.Config{})
			// base_url points the executor at the local test server instead of
			// the real Codex backend.
			auth := &cliproxyauth.Auth{Attributes: map[string]string{
				"base_url": server.URL,
				"api_key":  "test",
			}}
			resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
				Model:   "gpt-5.4",
				Payload: []byte(tc.payload),
			}, cliproxyexecutor.Options{
				SourceFormat: sdktranslator.FromString("openai-response"),
				Alt:          "responses/compact",
				Stream:       false,
			})
			if err != nil {
				t.Fatalf("Execute error: %v", err)
			}
			if gotPath != "/responses/compact" {
				t.Fatalf("path = %q, want %q", gotPath, "/responses/compact")
			}
			// The normalized field must exist, be a JSON string, and be empty.
			if !gjson.GetBytes(gotBody, "instructions").Exists() {
				t.Fatalf("expected instructions in compact request body, got %s", string(gotBody))
			}
			if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
				t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
			}
			if gjson.GetBytes(gotBody, "instructions").String() != "" {
				t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
			}
			// The upstream body is expected to pass through unchanged as the
			// response payload on this path.
			if string(resp.Payload) != `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}` {
				t.Fatalf("payload = %s", string(resp.Payload))
			}
		})
	}
}

View File

@@ -0,0 +1,123 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestCodexExecutorExecuteNormalizesNullInstructions verifies that the
// non-streaming Execute path rewrites a JSON-null "instructions" field to the
// empty string before the request reaches the upstream /responses endpoint.
func TestCodexExecutorExecuteNormalizesNullInstructions(t *testing.T) {
	// Capture the path and body the executor actually sends upstream.
	var gotPath string
	var gotBody []byte
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		body, _ := io.ReadAll(r.Body)
		gotBody = body
		// Execute consumes an SSE stream upstream even for non-streaming
		// callers, so reply with a minimal "response.completed" event.
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
	}))
	defer server.Close()
	executor := NewCodexExecutor(&config.Config{})
	// base_url points the executor at the local test server.
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL,
		"api_key":  "test",
	}}
	_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gpt-5.4",
		Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-response"),
		Stream:       false,
	})
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}
	if gotPath != "/responses" {
		t.Fatalf("path = %q, want %q", gotPath, "/responses")
	}
	// Null must have been normalized to an empty string value.
	if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
		t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
	}
	if gjson.GetBytes(gotBody, "instructions").String() != "" {
		t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
	}
}
// TestCodexExecutorExecuteStreamNormalizesNullInstructions verifies that the
// streaming ExecuteStream path also rewrites a JSON-null "instructions" field
// to the empty string before the request reaches the upstream /responses
// endpoint.
func TestCodexExecutorExecuteStreamNormalizesNullInstructions(t *testing.T) {
	// Capture the path and body the executor actually sends upstream.
	var gotPath string
	var gotBody []byte
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		body, _ := io.ReadAll(r.Body)
		gotBody = body
		// Reply with a minimal "response.completed" SSE event so the stream
		// terminates cleanly.
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
	}))
	defer server.Close()
	executor := NewCodexExecutor(&config.Config{})
	// base_url points the executor at the local test server.
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL,
		"api_key":  "test",
	}}
	result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gpt-5.4",
		Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-response"),
		Stream:       true,
	})
	if err != nil {
		t.Fatalf("ExecuteStream error: %v", err)
	}
	// Drain the chunk channel so the background goroutine finishes before we
	// inspect the captured request.
	for range result.Chunks {
	}
	if gotPath != "/responses" {
		t.Fatalf("path = %q, want %q", gotPath, "/responses")
	}
	// Null must have been normalized to an empty string value.
	if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
		t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
	}
	if gjson.GetBytes(gotBody, "instructions").String() != "" {
		t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
	}
}
// TestCodexExecutorCountTokensTreatsNullInstructionsAsEmpty verifies that
// CountTokens yields the same payload whether "instructions" is JSON null or
// an empty string.
func TestCodexExecutorCountTokensTreatsNullInstructionsAsEmpty(t *testing.T) {
	exec := NewCodexExecutor(&config.Config{})
	// Both requests share the same source-format options.
	opts := cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-response"),
	}

	withNull, err := exec.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
		Model:   "gpt-5.4",
		Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
	}, opts)
	if err != nil {
		t.Fatalf("CountTokens(null) error: %v", err)
	}

	withEmpty, err := exec.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
		Model:   "gpt-5.4",
		Payload: []byte(`{"model":"gpt-5.4","instructions":"","input":"hello"}`),
	}, opts)
	if err != nil {
		t.Fatalf("CountTokens(empty) error: %v", err)
	}

	if string(withNull.Payload) != string(withEmpty.Payload) {
		t.Fatalf("token count payload mismatch:\nnull=%s\nempty=%s", string(withNull.Payload), string(withEmpty.Payload))
	}
}

View File

@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
}) })
} }
// TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit verifies that a 400
// "model at capacity" upstream response is reclassified as a retryable 429
// without an explicit Retry-After hint.
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
	capacityBody := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
	capErr := newCodexStatusErr(http.StatusBadRequest, capacityBody)
	if code := capErr.StatusCode(); code != http.StatusTooManyRequests {
		t.Fatalf("status code = %d, want %d", code, http.StatusTooManyRequests)
	}
	if retryAfter := capErr.RetryAfter(); retryAfter != nil {
		t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *retryAfter)
	}
}
func itoa(v int64) string { func itoa(v int64) string {
return strconv.FormatInt(v, 10) return strconv.FormatInt(v, 10)
} }

View File

@@ -0,0 +1,46 @@
package executor
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone verifies
// that when the response.completed event carries an empty "output" array, the
// assistant text is recovered from the preceding response.output_item.done event.
func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		// output_item.done carries the text; the completion event has output:[].
		_, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n"))
		_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n"))
	}))
	defer upstream.Close()

	exec := NewCodexExecutor(&config.Config{})
	testAuth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": upstream.URL,
		"api_key":  "test",
	}}
	resp, err := exec.Execute(context.Background(), testAuth, cliproxyexecutor.Request{
		Model:   "gpt-5.4-mini",
		Payload: []byte(`{"model":"gpt-5.4-mini","messages":[{"role":"user","content":"Say ok"}]}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai"),
		Stream:       false,
	})
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}
	content := gjson.GetBytes(resp.Payload, "choices.0.message.content").String()
	if content != "ok" {
		t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", content, "ok", string(resp.Payload))
	}
}

View File

@@ -15,10 +15,12 @@ import (
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -44,10 +46,18 @@ const (
type CodexWebsocketsExecutor struct { type CodexWebsocketsExecutor struct {
*CodexExecutor *CodexExecutor
sessMu sync.Mutex store *codexWebsocketSessionStore
}
type codexWebsocketSessionStore struct {
mu sync.Mutex
sessions map[string]*codexWebsocketSession sessions map[string]*codexWebsocketSession
} }
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
sessions: make(map[string]*codexWebsocketSession),
}
type codexWebsocketSession struct { type codexWebsocketSession struct {
sessionID string sessionID string
@@ -71,7 +81,7 @@ type codexWebsocketSession struct {
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor { func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
return &CodexWebsocketsExecutor{ return &CodexWebsocketsExecutor{
CodexExecutor: NewCodexExecutor(cfg), CodexExecutor: NewCodexExecutor(cfg),
sessions: make(map[string]*codexWebsocketSession), store: globalCodexWebsocketSessionStore,
} }
} }
@@ -155,8 +165,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
baseURL = "https://chatgpt.com/backend-api/codex" baseURL = "https://chatgpt.com/backend-api/codex"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
@@ -173,8 +183,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -209,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
wsReqBody := buildCodexWebsocketRequestBody(body) wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ wsReqLog := helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -219,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthLabel: authLabel, AuthLabel: authLabel,
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) }
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if respHS != nil {
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil { if errDial != nil {
bodyErr := websocketHandshakeBody(respHS) bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 { if respHS != nil {
appendAPIResponseChunk(ctx, e.cfg, bodyErr) helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
} }
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.Execute(ctx, auth, req, opts) return e.CodexExecutor.Execute(ctx, auth, req, opts)
@@ -236,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if respHS != nil && respHS.StatusCode > 0 { if respHS != nil && respHS.StatusCode > 0 {
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
} }
recordAPIResponseError(ctx, e.cfg, errDial) helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
return resp, errDial return resp, errDial
} }
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil { if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL) logCodexWebsocketConnected(executionSessionID, authID, wsURL)
defer func() { defer func() {
@@ -268,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
// Retry once with a fresh websocket connection. This is mainly to handle // Retry once with a fresh websocket connection. This is mainly to handle
// upstream closing the socket between sequential requests within the same // upstream closing the socket between sequential requests within the same
// execution session. // execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil { if errDialRetry == nil && connRetry != nil {
wsReqBodyRetry := buildCodexWebsocketRequestBody(body) wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -282,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil { if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
conn = connRetry conn = connRetry
wsReqBody = wsReqBodyRetry wsReqBody = wsReqBodyRetry
} else { } else {
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
recordAPIResponseError(ctx, e.cfg, errSendRetry) helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
return resp, errSendRetry return resp, errSendRetry
} }
} else { } else {
recordAPIResponseError(ctx, e.cfg, errDialRetry) closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
return resp, errDialRetry return resp, errDialRetry
} }
} else { } else {
recordAPIResponseError(ctx, e.cfg, errSend) helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
return resp, errSend return resp, errSend
} }
} }
@@ -306,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh) msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
return resp, errRead return resp, errRead
} }
if msgType != websocket.TextMessage { if msgType != websocket.TextMessage {
@@ -315,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
} }
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
return resp, err return resp, err
} }
continue continue
@@ -325,21 +335,21 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if len(payload) == 0 { if len(payload) == 0 {
continue continue
} }
appendAPIResponseChunk(ctx, e.cfg, payload) helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok { if wsErr, ok := parseCodexWebsocketError(payload); ok {
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
} }
recordAPIResponseError(ctx, e.cfg, wsErr) helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
return resp, wsErr return resp, wsErr
} }
payload = normalizeCodexWebsocketCompletion(payload) payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String() eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" { if eventType == "response.completed" {
if detail, ok := parseCodexUsage(payload); ok { if detail, ok := helps.ParseCodexUsage(payload); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param)
@@ -364,8 +374,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
baseURL = "https://chatgpt.com/backend-api/codex" baseURL = "https://chatgpt.com/backend-api/codex"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
@@ -376,8 +386,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, err return nil, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses" httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
wsURL, err := buildCodexResponsesWebsocketURL(httpURL) wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
@@ -403,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
wsReqBody := buildCodexWebsocketRequestBody(body) wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ wsReqLog := helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -413,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthLabel: authLabel, AuthLabel: authLabel,
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) }
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header var upstreamHeaders http.Header
if respHS != nil { if respHS != nil {
upstreamHeaders = respHS.Header.Clone() upstreamHeaders = respHS.Header.Clone()
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
} }
if errDial != nil { if errDial != nil {
bodyErr := websocketHandshakeBody(respHS) bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 { if respHS != nil {
appendAPIResponseChunk(ctx, e.cfg, bodyErr) helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
} }
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts) return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
@@ -432,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if respHS != nil && respHS.StatusCode > 0 { if respHS != nil && respHS.StatusCode > 0 {
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
} }
recordAPIResponseError(ctx, e.cfg, errDial) helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
if sess != nil { if sess != nil {
sess.reqMu.Unlock() sess.reqMu.Unlock()
} }
return nil, errDial return nil, errDial
} }
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil { if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL) logCodexWebsocketConnected(executionSessionID, authID, wsURL)
@@ -451,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
recordAPIResponseError(ctx, e.cfg, errSend) helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend) e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a new websocket connection for the same execution session. // Retry once with a new websocket connection for the same execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry != nil || connRetry == nil { if errDialRetry != nil || connRetry == nil {
recordAPIResponseError(ctx, e.cfg, errDialRetry) closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
sess.clearActive(readCh) sess.clearActive(readCh)
sess.reqMu.Unlock() sess.reqMu.Unlock()
return nil, errDialRetry return nil, errDialRetry
} }
wsReqBodyRetry := buildCodexWebsocketRequestBody(body) wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL, URL: wsURL,
Method: "WEBSOCKET", Method: "WEBSOCKET",
Headers: wsHeaders.Clone(), Headers: wsHeaders.Clone(),
@@ -475,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil { if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
recordAPIResponseError(ctx, e.cfg, errSendRetry) helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
sess.clearActive(readCh) sess.clearActive(readCh)
sess.reqMu.Unlock() sess.reqMu.Unlock()
@@ -542,8 +554,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
terminateReason = "read_error" terminateReason = "read_error"
terminateErr = errRead terminateErr = errRead
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
_ = send(cliproxyexecutor.StreamChunk{Err: errRead}) _ = send(cliproxyexecutor.StreamChunk{Err: errRead})
return return
} }
@@ -552,8 +564,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
err = fmt.Errorf("codex websockets executor: unexpected binary message") err = fmt.Errorf("codex websockets executor: unexpected binary message")
terminateReason = "unexpected_binary" terminateReason = "unexpected_binary"
terminateErr = err terminateErr = err
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
} }
@@ -567,13 +579,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if len(payload) == 0 { if len(payload) == 0 {
continue continue
} }
appendAPIResponseChunk(ctx, e.cfg, payload) helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok { if wsErr, ok := parseCodexWebsocketError(payload); ok {
terminateReason = "upstream_error" terminateReason = "upstream_error"
terminateErr = wsErr terminateErr = wsErr
recordAPIResponseError(ctx, e.cfg, wsErr) helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
if sess != nil { if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
} }
@@ -584,8 +596,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
payload = normalizeCodexWebsocketCompletion(payload) payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String() eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" || eventType == "response.done" { if eventType == "response.completed" || eventType == "response.done" {
if detail, ok := parseCodexUsage(payload); ok { if detail, ok := helps.ParseCodexUsage(payload); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
} }
@@ -722,7 +734,7 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
} }
switch setting.URL.Scheme { switch setting.URL.Scheme {
case "socks5": case "socks5", "socks5h":
var proxyAuth *proxy.Auth var proxyAuth *proxy.Auth
if setting.URL.User != nil { if setting.URL.User != nil {
username := setting.URL.User.Username() username := setting.URL.User.Username()
@@ -767,19 +779,19 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
return rawJSON, headers return rawJSON, headers
} }
var cache codexCache var cache helps.CodexCache
if from == "claude" { if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id") userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() { if userIDResult.Exists() {
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
if cached, ok := getCodexCache(key); ok { if cached, ok := helps.GetCodexCache(key); ok {
cache = cached cache = cached
} else { } else {
cache = codexCache{ cache = helps.CodexCache{
ID: uuid.New().String(), ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour), Expire: time.Now().Add(1 * time.Hour),
} }
setCodexCache(key, cache) helps.SetCodexCache(key, cache)
} }
} }
} else if from == "openai-response" { } else if from == "openai-response" {
@@ -791,7 +803,6 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
if cache.ID != "" { if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID) rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
headers.Set("Conversation_id", cache.ID) headers.Set("Conversation_id", cache.ID)
headers.Set("Session_id", cache.ID)
} }
return rawJSON, headers return rawJSON, headers
@@ -806,11 +817,11 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
} }
var ginHeaders http.Header var ginHeaders http.Header
if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil { if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header ginHeaders = ginCtx.Request.Header.Clone()
} }
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth) _, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "") ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
@@ -826,8 +837,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
betaHeader = codexResponsesWebsocketBetaHeaderValue betaHeader = codexResponsesWebsocketBetaHeaderValue
} }
headers.Set("OpenAI-Beta", betaHeader) headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString()) if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent) misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
}
headers.Del("User-Agent")
isAPIKey := false isAPIKey := false
if auth != nil && auth.Attributes != nil { if auth != nil && auth.Attributes != nil {
@@ -1011,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
return line return line
} }
// websocketUpgradeRequestLog derives the HTTP upgrade-request view of a
// websocket request log entry: GET method, the websocket-upgrade form of the
// URL, no body, and the Connection/Upgrade handshake headers filled in when
// they are absent. The input log entry is not mutated.
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
	out := info
	out.URL = helps.WebsocketUpgradeRequestURL(info.URL)
	out.Method = http.MethodGet
	out.Body = nil

	// Clone so header edits below never touch the caller's header map.
	headers := info.Headers.Clone()
	if headers == nil {
		headers = make(http.Header)
	}
	if strings.TrimSpace(headers.Get("Connection")) == "" {
		headers.Set("Connection", "Upgrade")
	}
	if strings.TrimSpace(headers.Get("Upgrade")) == "" {
		headers.Set("Upgrade", "websocket")
	}
	out.Headers = headers
	return out
}
// recordAPIWebsocketHandshake records the websocket handshake response status
// and headers in the API request log, then closes the handshake response body
// so the connection resources are released. A nil response is a no-op.
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
	if resp == nil {
		return
	}
	// Clone headers so the log entry is isolated from later header mutation.
	helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
	closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
}
func websocketHandshakeBody(resp *http.Response) []byte { func websocketHandshakeBody(resp *http.Response) []byte {
if resp == nil || resp.Body == nil { if resp == nil || resp.Body == nil {
return nil return nil
@@ -1055,16 +1094,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
if sessionID == "" { if sessionID == "" {
return nil return nil
} }
e.sessMu.Lock() if e == nil {
defer e.sessMu.Unlock() return nil
if e.sessions == nil {
e.sessions = make(map[string]*codexWebsocketSession)
} }
if sess, ok := e.sessions[sessionID]; ok && sess != nil { store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
defer store.mu.Unlock()
if store.sessions == nil {
store.sessions = make(map[string]*codexWebsocketSession)
}
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
return sess return sess
} }
sess := &codexWebsocketSession{sessionID: sessionID} sess := &codexWebsocketSession{sessionID: sessionID}
e.sessions[sessionID] = sess store.sessions[sessionID] = sess
return sess return sess
} }
@@ -1210,14 +1256,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
return return
} }
if sessionID == cliproxyauth.CloseAllExecutionSessionsID { if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
e.closeAllExecutionSessions("executor_replaced") // Executor replacement can happen during hot reload (config/credential changes).
// Do not force-close upstream websocket sessions here, otherwise in-flight
// downstream websocket requests get interrupted.
return return
} }
e.sessMu.Lock() store := e.store
sess := e.sessions[sessionID] if store == nil {
delete(e.sessions, sessionID) store = globalCodexWebsocketSessionStore
e.sessMu.Unlock() }
store.mu.Lock()
sess := store.sessions[sessionID]
delete(store.sessions, sessionID)
store.mu.Unlock()
e.closeExecutionSession(sess, "session_closed") e.closeExecutionSession(sess, "session_closed")
} }
@@ -1227,15 +1279,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
return return
} }
e.sessMu.Lock() store := e.store
sessions := make([]*codexWebsocketSession, 0, len(e.sessions)) if store == nil {
for sessionID, sess := range e.sessions { store = globalCodexWebsocketSessionStore
delete(e.sessions, sessionID) }
store.mu.Lock()
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
delete(store.sessions, sessionID)
if sess != nil { if sess != nil {
sessions = append(sessions, sess) sessions = append(sessions, sess)
} }
} }
e.sessMu.Unlock() store.mu.Unlock()
for i := range sessions { for i := range sessions {
e.closeExecutionSession(sessions[i], reason) e.closeExecutionSession(sessions[i], reason)
@@ -1243,6 +1299,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
} }
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) { func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
closeCodexWebsocketSession(sess, reason)
}
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
if sess == nil { if sess == nil {
return return
} }
@@ -1283,6 +1343,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
} }
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
// associated with the supplied auth ID.
//
// The function works in three phases: snapshot the session map, filter the
// snapshot by auth ID, then re-validate and remove matches before closing
// them. The session closes themselves run outside the store lock.
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
	authID = strings.TrimSpace(authID)
	if authID == "" {
		return
	}
	reason = strings.TrimSpace(reason)
	if reason == "" {
		// Default close reason when the caller supplies none.
		reason = "auth_removed"
	}
	store := globalCodexWebsocketSessionStore
	if store == nil {
		return
	}
	// sessionItem pairs a session pointer with its map key so phase 3 can
	// verify the map still holds the same session before deleting it.
	type sessionItem struct {
		sessionID string
		sess      *codexWebsocketSession
	}
	// Phase 1: snapshot every session while holding only store.mu. The
	// auth-ID check below acquires each session's connMu; keeping those two
	// acquisitions in separate phases avoids holding both locks at once —
	// presumably to sidestep lock-ordering conflicts with other paths; TODO
	// confirm the intended lock ordering.
	store.mu.Lock()
	items := make([]sessionItem, 0, len(store.sessions))
	for sessionID, sess := range store.sessions {
		items = append(items, sessionItem{sessionID: sessionID, sess: sess})
	}
	store.mu.Unlock()
	// Phase 2: filter the snapshot down to sessions owned by the target auth
	// ID, reading sess.authID under that session's own lock.
	matches := make([]sessionItem, 0)
	for i := range items {
		sess := items[i].sess
		if sess == nil {
			continue
		}
		sess.connMu.Lock()
		sessAuthID := strings.TrimSpace(sess.authID)
		sess.connMu.Unlock()
		if sessAuthID == authID {
			matches = append(matches, items[i])
		}
	}
	if len(matches) == 0 {
		return
	}
	// Phase 3: re-check each match against the live map — a session may have
	// been removed or replaced since the snapshot — and only remove entries
	// whose pointer still matches.
	toClose := make([]*codexWebsocketSession, 0, len(matches))
	store.mu.Lock()
	for i := range matches {
		current, ok := store.sessions[matches[i].sessionID]
		if !ok || current == nil || current != matches[i].sess {
			continue
		}
		delete(store.sessions, matches[i].sessionID)
		toClose = append(toClose, current)
	}
	store.mu.Unlock()
	// Close outside the store lock so a slow close cannot block the store.
	for i := range toClose {
		closeCodexWebsocketSession(toClose[i], reason)
	}
}
// CodexAutoExecutor routes Codex requests to the websocket transport only when: // CodexAutoExecutor routes Codex requests to the websocket transport only when:
// 1. The downstream transport is websocket, and // 1. The downstream transport is websocket, and
// 2. The selected auth enables websockets. // 2. The selected auth enables websockets.

View File

@@ -0,0 +1,48 @@
package executor
import (
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement verifies
// that sessions live in the shared global store: a new executor instance sees
// the same session, the close-all marker from a replaced executor leaves it in
// place, and an explicit close by session ID removes it.
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
	const sessionID = "test-session-store-survives-replace"

	// inStore reports whether the session ID is currently registered in the
	// global store, taking and releasing the store lock around the lookup.
	inStore := func() bool {
		globalCodexWebsocketSessionStore.mu.Lock()
		defer globalCodexWebsocketSessionStore.mu.Unlock()
		_, ok := globalCodexWebsocketSessionStore.sessions[sessionID]
		return ok
	}

	// Start from a clean slate in case a previous run left the entry behind.
	globalCodexWebsocketSessionStore.mu.Lock()
	delete(globalCodexWebsocketSessionStore.sessions, sessionID)
	globalCodexWebsocketSessionStore.mu.Unlock()

	firstExec := NewCodexWebsocketsExecutor(nil)
	created := firstExec.getOrCreateSession(sessionID)
	if created == nil {
		t.Fatalf("expected session to be created")
	}

	secondExec := NewCodexWebsocketsExecutor(nil)
	fetched := secondExec.getOrCreateSession(sessionID)
	if fetched == nil {
		t.Fatalf("expected session to be available across executors")
	}
	if created != fetched {
		t.Fatalf("expected the same session instance across executors")
	}

	// Closing with the replacement marker must not evict the shared session.
	firstExec.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
	if !inStore() {
		t.Fatalf("expected session to remain after executor replacement close marker")
	}

	// An explicit close by ID removes the entry from the global store.
	secondExec.CloseExecutionSession(sessionID)
	if inStore() {
		t.Fatalf("expected session to be removed after explicit close")
	}
}

View File

@@ -38,8 +38,8 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
} }
if got := headers.Get("User-Agent"); got != codexUserAgent { if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent) t.Fatalf("User-Agent = %s, want empty", got)
} }
if got := headers.Get("Version"); got != "" { if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got) t.Fatalf("Version = %q, want empty", got)
@@ -97,8 +97,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg) headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" { if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0") t.Fatalf("User-Agent = %s, want empty", got)
} }
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" { if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b") t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
@@ -129,8 +129,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg) got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" { if gotVal := got.Get("User-Agent"); gotVal != "" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua") t.Fatalf("User-Agent = %s, want empty", gotVal)
} }
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" { if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta") t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
@@ -155,8 +155,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg) headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "config-ua" { if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua") t.Fatalf("User-Agent = %s, want empty", got)
} }
if got := headers.Get("x-codex-beta-features"); got != "client-beta" { if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta") t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
@@ -177,8 +177,8 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg) headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
if got := headers.Get("User-Agent"); got != codexUserAgent { if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent) t.Fatalf("User-Agent = %s, want empty", got)
} }
if got := headers.Get("x-codex-beta-features"); got != "" { if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got) t.Fatalf("x-codex-beta-features = %q, want empty", got)

View File

@@ -0,0 +1,129 @@
package executor
import (
"context"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
// newProxyAwareHTTPClient returns an HTTP client configured from cfg and auth
// with the given timeout; it delegates to helps.NewProxyAwareHTTPClient.
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
}
// parseOpenAIUsage extracts a usage detail from an OpenAI response body by
// delegating to helps.ParseOpenAIUsage.
func parseOpenAIUsage(data []byte) usage.Detail {
return helps.ParseOpenAIUsage(data)
}
// parseOpenAIStreamUsage extracts a usage detail from a single OpenAI stream
// line; the boolean reports whether the line carried usage data. Delegates to
// helps.ParseOpenAIStreamUsage.
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return helps.ParseOpenAIStreamUsage(line)
}
// parseOpenAIResponsesUsage extracts a usage detail from an OpenAI Responses
// body. It delegates to the same helper as parseOpenAIUsage
// (helps.ParseOpenAIUsage).
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
return helps.ParseOpenAIUsage(data)
}
// parseOpenAIResponsesStreamUsage extracts a usage detail from a single OpenAI
// Responses stream line. It delegates to the same helper as
// parseOpenAIStreamUsage (helps.ParseOpenAIStreamUsage).
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
return helps.ParseOpenAIStreamUsage(line)
}
// getTokenizer returns the tokenizer codec for model via
// helps.TokenizerForModel.
// NOTE(review): duplicates tokenizerForModel below — consider consolidating.
func getTokenizer(model string) (tokenizer.Codec, error) {
return helps.TokenizerForModel(model)
}
// countOpenAIChatTokens counts the tokens of an OpenAI chat payload using enc;
// it delegates to helps.CountOpenAIChatTokens.
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return helps.CountOpenAIChatTokens(enc, payload)
}
// countClaudeChatTokens counts the tokens of a Claude chat payload using enc;
// it delegates to helps.CountClaudeChatTokens.
func countClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return helps.CountClaudeChatTokens(enc, payload)
}
// buildOpenAIUsageJSON renders count as an OpenAI-style usage JSON payload via
// helps.BuildOpenAIUsageJSON.
func buildOpenAIUsageJSON(count int64) []byte {
return helps.BuildOpenAIUsageJSON(count)
}
// upstreamRequestLog is a package-local alias for helps.UpstreamRequestLog.
type upstreamRequestLog = helps.UpstreamRequestLog
// recordAPIRequest records an upstream request log entry via
// helps.RecordAPIRequest.
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
helps.RecordAPIRequest(ctx, cfg, info)
}
// recordAPIResponseMetadata records the upstream response status and headers
// via helps.RecordAPIResponseMetadata.
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
helps.RecordAPIResponseMetadata(ctx, cfg, status, headers)
}
// recordAPIResponseError records an upstream response error via
// helps.RecordAPIResponseError.
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
helps.RecordAPIResponseError(ctx, cfg, err)
}
// appendAPIResponseChunk appends a response body chunk to the request log via
// helps.AppendAPIResponseChunk.
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
helps.AppendAPIResponseChunk(ctx, cfg, chunk)
}
// payloadRequestedModel resolves the model requested in opts, returning
// fallback when none is present; it delegates to helps.PayloadRequestedModel.
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
return helps.PayloadRequestedModel(opts, fallback)
}
// applyPayloadConfigWithRoot applies config-driven payload rewrites rooted at
// root for the given model/protocol; it delegates to
// helps.ApplyPayloadConfigWithRoot.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
return helps.ApplyPayloadConfigWithRoot(cfg, model, protocol, root, payload, original, requestedModel)
}
// summarizeErrorBody produces a short loggable summary of an error response
// body given its content type; it delegates to helps.SummarizeErrorBody.
func summarizeErrorBody(contentType string, body []byte) string {
return helps.SummarizeErrorBody(contentType, body)
}
// apiKeyFromContext returns the API key carried in ctx, if any, via
// helps.APIKeyFromContext.
func apiKeyFromContext(ctx context.Context) string {
return helps.APIKeyFromContext(ctx)
}
// tokenizerForModel returns the tokenizer codec for model via
// helps.TokenizerForModel.
// NOTE(review): duplicates getTokenizer above — consider consolidating.
func tokenizerForModel(model string) (tokenizer.Codec, error) {
return helps.TokenizerForModel(model)
}
// collectOpenAIContent appends the text segments found in an OpenAI content
// value to segments; it delegates to helps.CollectOpenAIContent.
func collectOpenAIContent(content gjson.Result, segments *[]string) {
helps.CollectOpenAIContent(content, segments)
}
// usageReporter is a nil-safe wrapper around helps.UsageReporter; all of its
// methods are no-ops on a nil receiver or a nil inner reporter.
type usageReporter struct {
reporter *helps.UsageReporter
}
// newUsageReporter constructs a usageReporter backed by helps.NewUsageReporter
// for the given provider, model, and auth.
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
return &usageReporter{reporter: helps.NewUsageReporter(ctx, provider, model, auth)}
}
// publish forwards detail to the underlying reporter. Safe to call on a nil
// receiver or when no reporter is attached.
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
	if r != nil && r.reporter != nil {
		r.reporter.Publish(ctx, detail)
	}
}
// publishFailure reports a failed request to the underlying reporter. Safe to
// call on a nil receiver or when no reporter is attached.
func (r *usageReporter) publishFailure(ctx context.Context) {
	if r != nil && r.reporter != nil {
		r.reporter.PublishFailure(ctx)
	}
}
// trackFailure hands errPtr to the underlying reporter's TrackFailure. Safe to
// call on a nil receiver or when no reporter is attached.
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
	if r != nil && r.reporter != nil {
		r.reporter.TrackFailure(ctx, errPtr)
	}
}
// ensurePublished asks the underlying reporter to flush any unpublished usage.
// Safe to call on a nil receiver or when no reporter is attached.
func (r *usageReporter) ensurePublished(ctx context.Context) {
	if r != nil && r.reporter != nil {
		r.reporter.EnsurePublished(ctx)
	}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -81,6 +82,11 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
} }
req.Header.Set("Authorization", "Bearer "+tok.AccessToken) req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(req, "unknown") applyGeminiCLIHeaders(req, "unknown")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil return nil
} }
@@ -112,8 +118,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err return resp, err
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
@@ -132,8 +138,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
} }
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload) basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel) basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -190,7 +196,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel) applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "application/json") reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(), Headers: reqHTTP.Header.Clone(),
@@ -204,7 +211,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
httpResp, errDo := httpClient.Do(reqHTTP) httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
err = errDo err = errDo
return resp, err return resp, err
} }
@@ -213,15 +220,15 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose) log.Errorf("gemini cli executor: close response body error: %v", errClose)
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead err = errRead
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 { if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
reporter.publish(ctx, parseGeminiCLIUsage(data)) reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
var param any var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param) out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -230,7 +237,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...) lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 { if httpResp.StatusCode == 429 {
if idx+1 < len(models) { if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -245,7 +252,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
} }
if len(lastBody) > 0 { if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody) helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
} }
if lastStatus == 0 { if lastStatus == 0 {
lastStatus = 429 lastStatus = 429
@@ -266,8 +273,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
return nil, err return nil, err
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
@@ -286,8 +293,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
} }
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload) basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel) basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
projectID := resolveGeminiProjectID(auth) projectID := resolveGeminiProjectID(auth)
@@ -335,7 +342,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel) applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "text/event-stream") reqHTTP.Header.Set("Accept", "text/event-stream")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(), Headers: reqHTTP.Header.Clone(),
@@ -349,25 +357,25 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(reqHTTP) httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
err = errDo err = errDo
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose) log.Errorf("gemini cli executor: close response body error: %v", errClose)
} }
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead err = errRead
return nil, err return nil, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...) lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 { if httpResp.StatusCode == 429 {
if idx+1 < len(models) { if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -394,9 +402,9 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiCLIStreamUsage(line); ok { if detail, ok := helps.ParseGeminiCLIStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
if bytes.HasPrefix(line, dataTag) { if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
@@ -411,8 +419,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]} out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
return return
@@ -420,13 +428,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
data, errRead := io.ReadAll(resp.Body) data, errRead := io.ReadAll(resp.Body)
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errRead} out <- cliproxyexecutor.StreamChunk{Err: errRead}
return return
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data)) reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
var param any var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments { for i := range segments {
@@ -443,7 +451,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
} }
if len(lastBody) > 0 { if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody) helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
} }
if lastStatus == 0 { if lastStatus == 0 {
lastStatus = 429 lastStatus = 429
@@ -516,7 +524,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, baseModel) applyGeminiCLIHeaders(reqHTTP, baseModel)
reqHTTP.Header.Set("Accept", "application/json") reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(), Headers: reqHTTP.Header.Clone(),
@@ -530,17 +539,19 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
resp, errDo := httpClient.Do(reqHTTP) resp, errDo := httpClient.Do(reqHTTP)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo return cliproxyexecutor.Response{}, errDo
} }
data, errRead := io.ReadAll(resp.Body) data, errRead := io.ReadAll(resp.Body)
_ = resp.Body.Close() if errClose := resp.Body.Close(); errClose != nil {
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone()) helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead return cliproxyexecutor.Response{}, errRead
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode >= 200 && resp.StatusCode < 300 { if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
@@ -611,7 +622,7 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
} }
ctxToken := ctx ctxToken := ctx
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil { if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient) ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
} }
@@ -707,7 +718,7 @@ func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
} }
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout) return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
} }
func cloneMap(in map[string]any) map[string]any { func cloneMap(in map[string]any) map[string]any {

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -85,7 +86,7 @@ func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
@@ -110,8 +111,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
apiKey, bearer := geminiCreds(auth) apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
// Official Gemini API via API key or OAuth bearer // Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat from := opts.SourceFormat
@@ -130,8 +131,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
body = fixGeminiImageAspectRatio(baseModel, body) body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
action := "generateContent" action := "generateContent"
@@ -165,7 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -177,10 +178,10 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -188,21 +189,21 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
log.Errorf("gemini executor: close response body error: %v", errClose) log.Errorf("gemini executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
data, err := io.ReadAll(httpResp.Body) data, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data)) reporter.Publish(ctx, helps.ParseGeminiUsage(data))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -218,8 +219,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
apiKey, bearer := geminiCreds(auth) apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
@@ -237,8 +238,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
} }
body = fixGeminiImageAspectRatio(baseModel, body) body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
@@ -268,7 +269,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -280,17 +281,17 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose) log.Errorf("gemini executor: close response body error: %v", errClose)
} }
@@ -310,14 +311,14 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
filtered := FilterSSEUsageMetadata(line) filtered := helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(filtered) payload := helps.JSONPayload(filtered)
if len(payload) == 0 { if len(payload) == 0 {
continue continue
} }
if detail, ok := parseGeminiStreamUsage(payload); ok { if detail, ok := helps.ParseGeminiStreamUsage(payload); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines { for i := range lines {
@@ -329,8 +330,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
}() }()
@@ -381,7 +382,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -393,23 +394,27 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
resp, err := httpClient.Do(httpReq) resp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
} }
defer func() { _ = resp.Body.Close() }() defer func() {
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone()) if errClose := resp.Body.Close(); errClose != nil {
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
data, err := io.ReadAll(resp.Body) data, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, helps.SummarizeErrorBody(resp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)} return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
} }

View File

@@ -16,7 +16,9 @@ import (
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex" vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -227,7 +229,7 @@ func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
@@ -301,8 +303,8 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) { func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
var body []byte var body []byte
@@ -332,8 +334,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
} }
body = fixGeminiImageAspectRatio(baseModel, body) body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
} }
@@ -362,6 +364,11 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
return resp, statusErr{code: 500, msg: "internal server error"} return resp, statusErr{code: 500, msg: "internal server error"}
} }
applyGeminiHeaders(httpReq, auth) applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -369,7 +376,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -381,10 +388,10 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo return resp, errDo
} }
defer func() { defer func() {
@@ -392,21 +399,21 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead return resp, errRead
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data)) reporter.Publish(ctx, helps.ParseGeminiUsage(data))
// For Imagen models, convert response to Gemini format before translation // For Imagen models, convert response to Gemini format before translation
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview // This ensures Imagen responses use the same format as gemini-3-pro-image-preview
@@ -427,8 +434,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
@@ -447,8 +454,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
} }
body = fixGeminiImageAspectRatio(baseModel, body) body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, false) action := getVertexAction(baseModel, false)
@@ -477,6 +484,11 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
httpReq.Header.Set("x-goog-api-key", apiKey) httpReq.Header.Set("x-goog-api-key", apiKey)
} }
applyGeminiHeaders(httpReq, auth) applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -484,7 +496,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -496,10 +508,10 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo return resp, errDo
} }
defer func() { defer func() {
@@ -507,21 +519,21 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead return resp, errRead
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data)) reporter.Publish(ctx, helps.ParseGeminiUsage(data))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -532,8 +544,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) { func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
@@ -552,8 +564,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
} }
body = fixGeminiImageAspectRatio(baseModel, body) body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true) action := getVertexAction(baseModel, true)
@@ -581,6 +593,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
return nil, statusErr{code: 500, msg: "internal server error"} return nil, statusErr{code: 500, msg: "internal server error"}
} }
applyGeminiHeaders(httpReq, auth) applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -588,7 +605,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -600,17 +617,17 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo return nil, errDo
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
@@ -630,9 +647,9 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok { if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines { for i := range lines {
@@ -644,8 +661,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
}() }()
@@ -656,8 +673,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) { func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
@@ -676,8 +693,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
} }
body = fixGeminiImageAspectRatio(baseModel, body) body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true) action := getVertexAction(baseModel, true)
@@ -705,6 +722,11 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
httpReq.Header.Set("x-goog-api-key", apiKey) httpReq.Header.Set("x-goog-api-key", apiKey)
} }
applyGeminiHeaders(httpReq, auth) applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -712,7 +734,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -724,17 +746,17 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo return nil, errDo
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
@@ -754,9 +776,9 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok { if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines { for i := range lines {
@@ -768,8 +790,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
}() }()
@@ -812,6 +834,11 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
} }
applyGeminiHeaders(httpReq, auth) applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -819,7 +846,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -831,10 +858,10 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo return cliproxyexecutor.Response{}, errDo
} }
defer func() { defer func() {
@@ -842,19 +869,19 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
} }
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead return cliproxyexecutor.Response{}, errRead
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
@@ -896,6 +923,11 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
httpReq.Header.Set("x-goog-api-key", apiKey) httpReq.Header.Set("x-goog-api-key", apiKey)
} }
applyGeminiHeaders(httpReq, auth) applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -903,7 +935,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -915,10 +947,10 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq) httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil { if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo) helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo return cliproxyexecutor.Response{}, errDo
} }
defer func() { defer func() {
@@ -926,19 +958,19 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
} }
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil { if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead) helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead return cliproxyexecutor.Response{}, errRead
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
@@ -1012,7 +1044,7 @@ func vertexBaseURL(location string) string {
} }
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) { func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil { if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
} }
// Use cloud-platform scope for Vertex AI. // Use cloud-platform scope for Vertex AI.

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -17,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -40,7 +42,7 @@ const (
copilotEditorVersion = "vscode/1.107.0" copilotEditorVersion = "vscode/1.107.0"
copilotPluginVersion = "copilot-chat/0.35.0" copilotPluginVersion = "copilot-chat/0.35.0"
copilotIntegrationID = "vscode-chat" copilotIntegrationID = "vscode-chat"
copilotOpenAIIntent = "conversation-panel" copilotOpenAIIntent = "conversation-edits"
copilotGitHubAPIVer = "2025-04-01" copilotGitHubAPIVer = "2025-04-01"
) )
@@ -126,6 +128,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = e.normalizeModel(req.Model, body) body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body) body = flattenAssistantContent(body)
body = stripUnsupportedBetas(body)
// Detect vision content before input normalization removes messages // Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body) hasVision := detectVisionContent(body)
@@ -142,6 +145,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
if useResponses { if useResponses {
body = normalizeGitHubCopilotResponsesInput(body) body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body) body = normalizeGitHubCopilotResponsesTools(body)
body = applyGitHubCopilotResponsesDefaults(body)
} else { } else {
body = normalizeGitHubCopilotChatTools(body) body = normalizeGitHubCopilotChatTools(body)
} }
@@ -225,9 +229,10 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
if useResponses && from.String() == "claude" { if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data) converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else { } else {
data = normalizeGitHubCopilotReasoningField(data)
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param) converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
} }
resp = cliproxyexecutor.Response{Payload: converted} resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx) reporter.ensurePublished(ctx)
return resp, nil return resp, nil
} }
@@ -256,6 +261,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = e.normalizeModel(req.Model, body) body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body) body = flattenAssistantContent(body)
body = stripUnsupportedBetas(body)
// Detect vision content before input normalization removes messages // Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body) hasVision := detectVisionContent(body)
@@ -272,6 +278,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
if useResponses { if useResponses {
body = normalizeGitHubCopilotResponsesInput(body) body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body) body = normalizeGitHubCopilotResponsesTools(body)
body = applyGitHubCopilotResponsesDefaults(body)
} else { } else {
body = normalizeGitHubCopilotChatTools(body) body = normalizeGitHubCopilotChatTools(body)
} }
@@ -378,7 +385,20 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
if useResponses && from.String() == "claude" { if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param) chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else { } else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param) // Strip SSE "data: " prefix before reasoning field normalization,
// since normalizeGitHubCopilotReasoningField expects pure JSON.
// Re-wrap with the prefix afterward for the translator.
normalizedLine := bytes.Clone(line)
if bytes.HasPrefix(line, dataTag) {
sseData := bytes.TrimSpace(line[len(dataTag):])
if !bytes.Equal(sseData, []byte("[DONE]")) && gjson.ValidBytes(sseData) {
normalized := normalizeGitHubCopilotReasoningField(bytes.Clone(sseData))
if !bytes.Equal(normalized, sseData) {
normalizedLine = append(append([]byte(nil), dataTag...), normalized...)
}
}
}
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, normalizedLine, &param)
} }
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
@@ -400,9 +420,28 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}, nil }, nil
} }
// CountTokens is not supported for GitHub Copilot. // CountTokens estimates token count locally using tiktoken, since the GitHub
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { // Copilot API does not expose a dedicated token counting endpoint.
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"} func (e *GitHubCopilotExecutor) CountTokens(ctx context.Context, _ *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
enc, err := helps.TokenizerForModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: tokenizer init failed: %w", err)
}
count, err := helps.CountOpenAIChatTokens(enc, translated)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: token counting failed: %w", err)
}
usageJSON := helps.BuildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
} }
// Refresh validates the GitHub token is still working. // Refresh validates the GitHub token is still working.
@@ -491,46 +530,127 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
r.Header.Set("X-Request-Id", uuid.NewString()) r.Header.Set("X-Request-Id", uuid.NewString())
initiator := "user" initiator := "user"
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" { if isAgentInitiated(body) {
initiator = "agent" initiator = "agent"
} }
r.Header.Set("X-Initiator", initiator) r.Header.Set("X-Initiator", initiator)
} }
func detectLastConversationRole(body []byte) string { // isAgentInitiated determines whether the current request is agent-initiated
// (tool callbacks, continuations) rather than user-initiated (new user prompt).
//
// GitHub Copilot uses the X-Initiator header for billing:
// - "user" → consumes premium request quota
// - "agent" → free (tool loops, continuations)
//
// The challenge: Claude Code sends tool results as role:"user" messages with
// content type "tool_result". After translation to OpenAI format, the tool_result
// part becomes a separate role:"tool" message, but if the original Claude message
// also contained text content (e.g. skill invocations, attachment descriptions),
// a role:"user" message is emitted AFTER the tool message, making the last message
// appear user-initiated when it's actually part of an agent tool loop.
//
// VSCode Copilot Chat solves this with explicit flags (iterationNumber,
// isContinuation, subAgentInvocationId). Since CPA doesn't have these flags,
// we infer agent status by checking whether the conversation contains prior
// assistant/tool messages — if it does, the current request is a continuation.
//
// References:
// - opencode#8030, opencode#15824: same root cause and fix approach
// - vscode-copilot-chat: toolCallingLoop.ts (iterationNumber === 0)
// - pi-ai: github-copilot-headers.ts (last message role check)
func isAgentInitiated(body []byte) bool {
if len(body) == 0 { if len(body) == 0 {
return "" return false
} }
// Chat Completions API: check messages array
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
arr := messages.Array() arr := messages.Array()
if len(arr) == 0 {
return false
}
lastRole := ""
for i := len(arr) - 1; i >= 0; i-- { for i := len(arr) - 1; i >= 0; i-- {
if role := arr[i].Get("role").String(); role != "" { if r := arr[i].Get("role").String(); r != "" {
return role lastRole = r
break
} }
} }
// If last message is assistant or tool, clearly agent-initiated.
if lastRole == "assistant" || lastRole == "tool" {
return true
}
// If last message is "user", check whether it contains tool results
// (indicating a tool-loop continuation) or if the preceding message
// is an assistant tool_use. This is more precise than checking for
// any prior assistant message, which would false-positive on genuine
// multi-turn follow-ups.
if lastRole == "user" {
// Check if the last user message contains tool_result content
lastContent := arr[len(arr)-1].Get("content")
if lastContent.Exists() && lastContent.IsArray() {
for _, part := range lastContent.Array() {
if part.Get("type").String() == "tool_result" {
return true
}
}
}
// Check if the second-to-last message is an assistant with tool_use
if len(arr) >= 2 {
prev := arr[len(arr)-2]
if prev.Get("role").String() == "assistant" {
prevContent := prev.Get("content")
if prevContent.Exists() && prevContent.IsArray() {
for _, part := range prevContent.Array() {
if part.Get("type").String() == "tool_use" {
return true
}
}
}
}
}
}
return false
} }
// Responses API: check input array
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() { if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
arr := inputs.Array() arr := inputs.Array()
for i := len(arr) - 1; i >= 0; i-- { if len(arr) == 0 {
item := arr[i] return false
}
// Most Responses input items carry a top-level role. // Check last item
if role := item.Get("role").String(); role != "" { last := arr[len(arr)-1]
return role if role := last.Get("role").String(); role == "assistant" {
return true
}
switch last.Get("type").String() {
case "function_call", "function_call_arguments", "computer_call":
return true
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
return true
}
// If last item is user-role, check for prior non-user items
for _, item := range arr {
if role := item.Get("role").String(); role == "assistant" {
return true
} }
switch item.Get("type").String() { switch item.Get("type").String() {
case "function_call", "function_call_arguments", "computer_call": case "function_call", "function_call_output", "function_call_response",
return "assistant" "function_call_arguments", "computer_call", "computer_call_output":
case "function_call_output", "function_call_response", "tool_result", "computer_call_output": return true
return "tool"
} }
} }
} }
return "" return false
} }
// detectVisionContent checks if the request body contains vision/image content. // detectVisionContent checks if the request body contains vision/image content.
@@ -572,6 +692,85 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
return body return body
} }
// copilotUnsupportedBetas lists beta headers that are Anthropic-specific and
// must not be forwarded to GitHub Copilot. The context-1m beta enables 1M
// context on Anthropic's API, but Copilot's Claude models are limited to
// ~128K-200K. Passing it through would not enable 1M on Copilot, but stripping
// it from the translated body avoids confusing downstream translators.
var copilotUnsupportedBetas = []string{
"context-1m-2025-08-07",
}
// stripUnsupportedBetas removes Anthropic-specific beta entries from the
// translated request body. In OpenAI format the betas may appear under
// "metadata.betas" or a top-level "betas" array; in Claude format they sit at
// "betas". This function checks all known locations.
func stripUnsupportedBetas(body []byte) []byte {
betaPaths := []string{"betas", "metadata.betas"}
for _, path := range betaPaths {
arr := gjson.GetBytes(body, path)
if !arr.Exists() || !arr.IsArray() {
continue
}
var filtered []string
changed := false
for _, item := range arr.Array() {
beta := item.String()
if isCopilotUnsupportedBeta(beta) {
changed = true
continue
}
filtered = append(filtered, beta)
}
if !changed {
continue
}
if len(filtered) == 0 {
body, _ = sjson.DeleteBytes(body, path)
} else {
body, _ = sjson.SetBytes(body, path, filtered)
}
}
return body
}
func isCopilotUnsupportedBeta(beta string) bool {
return slices.Contains(copilotUnsupportedBetas, beta)
}
// normalizeGitHubCopilotReasoningField maps Copilot's non-standard
// 'reasoning_text' field to the standard OpenAI 'reasoning_content' field
// that the SDK translator expects. This handles both streaming deltas
// (choices[].delta.reasoning_text) and non-streaming messages
// (choices[].message.reasoning_text). The field is only renamed when
// 'reasoning_content' is absent or null, preserving standard responses.
// All choices are processed to support n>1 requests.
func normalizeGitHubCopilotReasoningField(data []byte) []byte {
choices := gjson.GetBytes(data, "choices")
if !choices.Exists() || !choices.IsArray() {
return data
}
for i := range choices.Array() {
// Non-streaming: choices[i].message.reasoning_text
msgRT := fmt.Sprintf("choices.%d.message.reasoning_text", i)
msgRC := fmt.Sprintf("choices.%d.message.reasoning_content", i)
if rt := gjson.GetBytes(data, msgRT); rt.Exists() && rt.String() != "" {
if rc := gjson.GetBytes(data, msgRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
data, _ = sjson.SetBytes(data, msgRC, rt.String())
}
}
// Streaming: choices[i].delta.reasoning_text
deltaRT := fmt.Sprintf("choices.%d.delta.reasoning_text", i)
deltaRC := fmt.Sprintf("choices.%d.delta.reasoning_content", i)
if rt := gjson.GetBytes(data, deltaRT); rt.Exists() && rt.String() != "" {
if rc := gjson.GetBytes(data, deltaRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
data, _ = sjson.SetBytes(data, deltaRC, rt.String())
}
}
}
return data
}
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool { func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
if sourceFormat.String() == "openai-response" { if sourceFormat.String() == "openai-response" {
return true return true
@@ -596,12 +795,7 @@ func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
} }
func containsEndpoint(endpoints []string, endpoint string) bool { func containsEndpoint(endpoints []string, endpoint string) bool {
for _, item := range endpoints { return slices.Contains(endpoints, endpoint)
if item == endpoint {
return true
}
}
return false
} }
// flattenAssistantContent converts assistant message content from array format // flattenAssistantContent converts assistant message content from array format
@@ -856,6 +1050,32 @@ func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
return body return body
} }
// applyGitHubCopilotResponsesDefaults sets required fields for the Responses API
// that both vscode-copilot-chat and pi-ai always include.
//
// References:
// - vscode-copilot-chat: src/platform/endpoint/node/responsesApi.ts
// - pi-ai (badlogic/pi-mono): packages/ai/src/providers/openai-responses.ts
func applyGitHubCopilotResponsesDefaults(body []byte) []byte {
// store: false — prevents request/response storage
if !gjson.GetBytes(body, "store").Exists() {
body, _ = sjson.SetBytes(body, "store", false)
}
// include: ["reasoning.encrypted_content"] — enables reasoning content
// reuse across turns, avoiding redundant computation
if !gjson.GetBytes(body, "include").Exists() {
body, _ = sjson.SetRawBytes(body, "include", []byte(`["reasoning.encrypted_content"]`))
}
// If reasoning.effort is set but reasoning.summary is not, default to "auto"
if gjson.GetBytes(body, "reasoning.effort").Exists() && !gjson.GetBytes(body, "reasoning.summary").Exists() {
body, _ = sjson.SetBytes(body, "reasoning.summary", "auto")
}
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte { func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools") tools := gjson.GetBytes(body, "tools")
if tools.Exists() { if tools.Exists() {
@@ -1406,6 +1626,21 @@ func FetchGitHubCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
} }
// Override with real limits from the Copilot API when available.
// The API returns per-account limits (individual vs business) under
// capabilities.limits, which are more accurate than our static
// fallback values. We use max_prompt_tokens as ContextLength because
// that's the hard limit the Copilot API enforces on prompt size —
// exceeding it triggers "prompt token count exceeds the limit" errors.
if limits := entry.Limits(); limits != nil {
if limits.MaxPromptTokens > 0 {
m.ContextLength = limits.MaxPromptTokens
}
if limits.MaxOutputTokens > 0 {
m.MaxCompletionTokens = limits.MaxOutputTokens
}
}
models = append(models, m) models = append(models, m)
} }

View File

@@ -1,11 +1,14 @@
package executor package executor
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -72,26 +75,39 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
} }
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) { func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
t.Parallel() // Not parallel: shares global model registry with DynamicRegistryWinsOverStatic.
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") { if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses") t.Fatal("expected responses-only registry model to use /responses")
} }
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
t.Fatal("expected responses-only registry model to use /responses")
}
} }
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) { func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
t.Parallel() // Not parallel: mutates global model registry, conflicts with RegistryResponsesOnlyModel.
reg := registry.GetGlobalRegistry() reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client" clientID := "github-copilot-test-client"
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{ reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
ID: "gpt-5.4", {
SupportedEndpoints: []string{"/chat/completions", "/responses"}, ID: "gpt-5.4",
}}) SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5.4-mini",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
})
defer reg.UnregisterClient(clientID) defer reg.UnregisterClient(clientID)
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") { if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback") t.Fatal("expected dynamic registry definition to take precedence over static fallback")
} }
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
} }
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) { func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
@@ -238,14 +254,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing
t.Parallel() t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`) resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp) out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" { if gjson.GetBytes(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String()) t.Fatalf("type = %q, want message", gjson.GetBytes(out, "type").String())
} }
if gjson.Get(out, "content.0.type").String() != "text" { if gjson.GetBytes(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String()) t.Fatalf("content.0.type = %q, want text", gjson.GetBytes(out, "content.0.type").String())
} }
if gjson.Get(out, "content.0.text").String() != "hello" { if gjson.GetBytes(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String()) t.Fatalf("content.0.text = %q, want hello", gjson.GetBytes(out, "content.0.text").String())
} }
} }
@@ -253,14 +269,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *test
t.Parallel() t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`) resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp) out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" { if gjson.GetBytes(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String()) t.Fatalf("content.0.type = %q, want tool_use", gjson.GetBytes(out, "content.0.type").String())
} }
if gjson.Get(out, "content.0.name").String() != "sum" { if gjson.GetBytes(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String()) t.Fatalf("content.0.name = %q, want sum", gjson.GetBytes(out, "content.0.name").String())
} }
if gjson.Get(out, "stop_reason").String() != "tool_use" { if gjson.GetBytes(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String()) t.Fatalf("stop_reason = %q, want tool_use", gjson.GetBytes(out, "stop_reason").String())
} }
} }
@@ -269,18 +285,24 @@ func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.
var param any var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param) created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") { if len(created) == 0 || !strings.Contains(string(created[0]), "message_start") {
t.Fatalf("created events = %#v, want message_start", created) t.Fatalf("created events = %#v, want message_start", created)
} }
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param) delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "") var joinedDelta string
for _, d := range delta {
joinedDelta += string(d)
}
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") { if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta) t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
} }
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param) completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "") var joinedCompleted string
for _, c := range completed {
joinedCompleted += string(c)
}
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") { if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed) t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
} }
@@ -299,15 +321,17 @@ func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
} }
} }
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) { func TestApplyHeaders_XInitiator_AgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
t.Parallel() t.Parallel()
e := &GitHubCopilotExecutor{} e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Last role governs the initiator decision. // When the last role is "user" and the message contains tool_result content,
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`) // the request is a continuation (e.g. Claude tool result translated to a
// synthetic user message). Should be "agent".
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu1","content":"file contents..."}]}]}`)
e.applyHeaders(req, "token", body) e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" { if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want user (last role is user)", got) t.Fatalf("X-Initiator = %q, want agent (last user contains tool_result)", got)
} }
} }
@@ -315,10 +339,11 @@ func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
t.Parallel() t.Parallel()
e := &GitHubCopilotExecutor{} e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// When the last message has role "tool", it's clearly agent-initiated.
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`) body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
e.applyHeaders(req, "token", body) e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" { if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got) t.Fatalf("X-Initiator = %q, want agent (last role is tool)", got)
} }
} }
@@ -333,14 +358,15 @@ func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
} }
} }
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) { func TestApplyHeaders_XInitiator_InputArrayAgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
t.Parallel() t.Parallel()
e := &GitHubCopilotExecutor{} e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil) req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Responses API: last item is user-role but history contains assistant → agent.
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`) body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
e.applyHeaders(req, "token", body) e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" { if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want user (last role is user)", got) t.Fatalf("X-Initiator = %q, want agent (history has assistant)", got)
} }
} }
@@ -355,6 +381,33 @@ func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T)
} }
} }
// TestApplyHeaders_XInitiator_UserInMultiTurnNoTools verifies that a genuine
// multi-turn exchange (user → assistant plain text → user follow-up) with no
// tool messages anywhere is classified as user-initiated, not a false positive.
func TestApplyHeaders_XInitiator_UserInMultiTurnNoTools(t *testing.T) {
	t.Parallel()
	exec := &GitHubCopilotExecutor{}
	request, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
	payload := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"what is 2+2?"}]}`)
	exec.applyHeaders(request, "token", payload)
	got := request.Header.Get("X-Initiator")
	if got != "user" {
		t.Fatalf("X-Initiator = %q, want user (genuine multi-turn, no tools)", got)
	}
}
// TestApplyHeaders_XInitiator_UserFollowUpAfterToolHistory verifies that a
// genuine user question asked after a *completed* tool-use conversation is
// still user-initiated. This mirrors opencode's behavior: only active tool
// loops count as agent-initiated.
func TestApplyHeaders_XInitiator_UserFollowUpAfterToolHistory(t *testing.T) {
	t.Parallel()
	exec := &GitHubCopilotExecutor{}
	request, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
	payload := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"tool","tool_call_id":"tu1","content":"file data"},{"role":"assistant","content":"I read the file."},{"role":"user","content":"What did we do so far?"}]}`)
	exec.applyHeaders(request, "token", payload)
	got := request.Header.Get("X-Initiator")
	if got != "user" {
		t.Fatalf("X-Initiator = %q, want user (genuine follow-up after tool history)", got)
	}
}
// --- Tests for x-github-api-version header (Problem M) --- // --- Tests for x-github-api-version header (Problem M) ---
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) { func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
@@ -401,3 +454,364 @@ func TestDetectVisionContent_NoMessages(t *testing.T) {
t.Fatal("expected no vision content when messages field is absent") t.Fatal("expected no vision content when messages field is absent")
} }
} }
// --- Tests for applyGitHubCopilotResponsesDefaults ---
// TestApplyGitHubCopilotResponsesDefaults_SetsAllDefaults verifies that the
// defaults helper sets store=false, include=["reasoning.encrypted_content"],
// and reasoning.summary="auto" when the request carries reasoning.effort.
func TestApplyGitHubCopilotResponsesDefaults_SetsAllDefaults(t *testing.T) {
	t.Parallel()
	body := []byte(`{"input":"hello","reasoning":{"effort":"medium"}}`)
	got := applyGitHubCopilotResponsesDefaults(body)
	if gjson.GetBytes(got, "store").Bool() != false {
		t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
	}
	inc := gjson.GetBytes(got, "include")
	// Guard the index: a present-but-empty array would otherwise panic the
	// test on Array()[0] instead of failing it with a useful message.
	if !inc.IsArray() || len(inc.Array()) == 0 || inc.Array()[0].String() != "reasoning.encrypted_content" {
		t.Fatalf("include = %s, want [\"reasoning.encrypted_content\"]", inc.Raw)
	}
	if gjson.GetBytes(got, "reasoning.summary").String() != "auto" {
		t.Fatalf("reasoning.summary = %q, want auto", gjson.GetBytes(got, "reasoning.summary").String())
	}
}
// TestApplyGitHubCopilotResponsesDefaults_DoesNotOverrideExisting verifies
// that explicit caller-supplied store/include/reasoning.summary values are
// left untouched by the defaults helper.
func TestApplyGitHubCopilotResponsesDefaults_DoesNotOverrideExisting(t *testing.T) {
	t.Parallel()
	body := []byte(`{"input":"hello","store":true,"include":["other"],"reasoning":{"effort":"high","summary":"concise"}}`)
	got := applyGitHubCopilotResponsesDefaults(body)
	if gjson.GetBytes(got, "store").Bool() != true {
		t.Fatalf("store should not be overridden, got %s", gjson.GetBytes(got, "store").Raw)
	}
	// Guard the index so an emptied include array fails instead of panicking.
	incArr := gjson.GetBytes(got, "include").Array()
	if len(incArr) == 0 || incArr[0].String() != "other" {
		t.Fatalf("include should not be overridden, got %s", gjson.GetBytes(got, "include").Raw)
	}
	if gjson.GetBytes(got, "reasoning.summary").String() != "concise" {
		t.Fatalf("reasoning.summary should not be overridden, got %q", gjson.GetBytes(got, "reasoning.summary").String())
	}
}
// TestApplyGitHubCopilotResponsesDefaults_NoReasoningEffort verifies that
// store still defaults to false but reasoning.summary is left unset when the
// request has no reasoning.effort field.
func TestApplyGitHubCopilotResponsesDefaults_NoReasoningEffort(t *testing.T) {
	t.Parallel()
	input := []byte(`{"input":"hello"}`)
	out := applyGitHubCopilotResponsesDefaults(input)
	if gjson.GetBytes(out, "store").Bool() != false {
		t.Fatalf("store = %v, want false", gjson.GetBytes(out, "store").Raw)
	}
	// reasoning.summary should NOT be set when reasoning.effort is absent
	summary := gjson.GetBytes(out, "reasoning.summary")
	if summary.Exists() {
		t.Fatalf("reasoning.summary should not be set when reasoning.effort is absent, got %q", summary.String())
	}
}
// --- Tests for normalizeGitHubCopilotReasoningField ---
// TestNormalizeReasoningField_NonStreaming verifies that a non-streaming
// choice's reasoning_text is mirrored into reasoning_content.
func TestNormalizeReasoningField_NonStreaming(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"choices":[{"message":{"content":"hello","reasoning_text":"I think..."}}]}`)
	normalized := normalizeGitHubCopilotReasoningField(payload)
	if got := gjson.GetBytes(normalized, "choices.0.message.reasoning_content").String(); got != "I think..." {
		t.Fatalf("reasoning_content = %q, want %q", got, "I think...")
	}
}
// TestNormalizeReasoningField_Streaming verifies that a streaming delta's
// reasoning_text is mirrored into reasoning_content.
func TestNormalizeReasoningField_Streaming(t *testing.T) {
	t.Parallel()
	chunk := []byte(`{"choices":[{"delta":{"reasoning_text":"thinking delta"}}]}`)
	normalized := normalizeGitHubCopilotReasoningField(chunk)
	if got := gjson.GetBytes(normalized, "choices.0.delta.reasoning_content").String(); got != "thinking delta" {
		t.Fatalf("reasoning_content = %q, want %q", got, "thinking delta")
	}
}
// TestNormalizeReasoningField_PreservesExistingReasoningContent verifies that
// an already-present reasoning_content is never overwritten by reasoning_text.
func TestNormalizeReasoningField_PreservesExistingReasoningContent(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"choices":[{"message":{"reasoning_text":"old","reasoning_content":"existing"}}]}`)
	normalized := normalizeGitHubCopilotReasoningField(payload)
	if got := gjson.GetBytes(normalized, "choices.0.message.reasoning_content").String(); got != "existing" {
		t.Fatalf("reasoning_content = %q, want %q (should not overwrite)", got, "existing")
	}
}
// TestNormalizeReasoningField_MultiChoice verifies that every choice in a
// multi-choice response gets its own reasoning_content.
func TestNormalizeReasoningField_MultiChoice(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"choices":[{"message":{"reasoning_text":"thought-0"}},{"message":{"reasoning_text":"thought-1"}}]}`)
	normalized := normalizeGitHubCopilotReasoningField(payload)
	first := gjson.GetBytes(normalized, "choices.0.message.reasoning_content").String()
	second := gjson.GetBytes(normalized, "choices.1.message.reasoning_content").String()
	if first != "thought-0" {
		t.Fatalf("choices[0].reasoning_content = %q, want %q", first, "thought-0")
	}
	if second != "thought-1" {
		t.Fatalf("choices[1].reasoning_content = %q, want %q", second, "thought-1")
	}
}
// TestNormalizeReasoningField_NoChoices verifies that a payload without a
// choices array passes through byte-for-byte unchanged.
func TestNormalizeReasoningField_NoChoices(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"id":"chatcmpl-123"}`)
	normalized := normalizeGitHubCopilotReasoningField(payload)
	if string(normalized) != string(payload) {
		t.Fatalf("expected no change, got %s", string(normalized))
	}
}
// TestApplyHeaders_OpenAIIntentValue verifies the fixed Openai-Intent header
// value set on every outbound request.
func TestApplyHeaders_OpenAIIntentValue(t *testing.T) {
	t.Parallel()
	exec := &GitHubCopilotExecutor{}
	request, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
	exec.applyHeaders(request, "token", nil)
	intent := request.Header.Get("Openai-Intent")
	if intent != "conversation-edits" {
		t.Fatalf("Openai-Intent = %q, want conversation-edits", intent)
	}
}
// --- Tests for CountTokens (local tiktoken estimation) ---
// TestCountTokens_ReturnsPositiveCount checks that CountTokens produces a
// positive usage.prompt_tokens value for a simple OpenAI-format request.
// NOTE(review): nil auth appears to exercise the local tiktoken estimation
// path (no network) — confirm against the CountTokens implementation.
func TestCountTokens_ReturnsPositiveCount(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}`)
	resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
		Model:   "gpt-4o",
		Payload: body,
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai"),
	})
	if err != nil {
		t.Fatalf("CountTokens() error: %v", err)
	}
	if len(resp.Payload) == 0 {
		t.Fatal("CountTokens() returned empty payload")
	}
	// The response should contain a positive token count.
	tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
	if tokens <= 0 {
		t.Fatalf("expected positive token count, got %d", tokens)
	}
}
// TestCountTokens_ClaudeSourceFormatTranslates checks that a Claude-format
// request is translated and counted: the result carries either a top-level
// input_tokens (Claude response shape) or usage.prompt_tokens as a fallback,
// depending on translator registration.
func TestCountTokens_ClaudeSourceFormatTranslates(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	body := []byte(`{"model":"claude-sonnet-4","messages":[{"role":"user","content":"Tell me a joke"}],"max_tokens":1024}`)
	resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
		Model:   "claude-sonnet-4",
		Payload: body,
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("claude"),
	})
	if err != nil {
		t.Fatalf("CountTokens() error: %v", err)
	}
	// Claude source format → should get input_tokens in response
	inputTokens := gjson.GetBytes(resp.Payload, "input_tokens").Int()
	if inputTokens <= 0 {
		// Fallback: check usage.prompt_tokens (depends on translator registration)
		promptTokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
		if promptTokens <= 0 {
			t.Fatalf("expected positive token count, got payload: %s", resp.Payload)
		}
	}
}
// TestCountTokens_EmptyPayload verifies that an empty messages array yields a
// zero usage.prompt_tokens count and no error.
func TestCountTokens_EmptyPayload(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
		Model:   "gpt-4o",
		Payload: []byte(`{"model":"gpt-4o","messages":[]}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai"),
	})
	if err != nil {
		t.Fatalf("CountTokens() error: %v", err)
	}
	tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
	// Empty messages should return 0 tokens.
	if tokens != 0 {
		t.Fatalf("expected 0 tokens for empty messages, got %d", tokens)
	}
}
// TestStripUnsupportedBetas_RemovesContext1M verifies that the context-1m
// beta is removed from the top-level betas array while the remaining betas
// are preserved.
func TestStripUnsupportedBetas_RemovesContext1M(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"model":"claude-opus-4.6","betas":["interleaved-thinking-2025-05-14","context-1m-2025-08-07","claude-code-20250219"],"messages":[]}`)
	stripped := stripUnsupportedBetas(payload)
	betas := gjson.GetBytes(stripped, "betas")
	if !betas.Exists() {
		t.Fatal("betas field should still exist after stripping")
	}
	// One pass: fail if the stripped beta survived, note a preserved sibling.
	var keptOther bool
	for _, entry := range betas.Array() {
		switch entry.String() {
		case "context-1m-2025-08-07":
			t.Fatal("context-1m-2025-08-07 should have been stripped")
		case "interleaved-thinking-2025-05-14":
			keptOther = true
		}
	}
	if !keptOther {
		t.Fatal("other betas should be preserved")
	}
}
// TestStripUnsupportedBetas_NoBetasField verifies that a body without any
// betas field passes through unchanged.
func TestStripUnsupportedBetas_NoBetasField(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"model":"gpt-4o","messages":[]}`)
	out := stripUnsupportedBetas(payload)
	// Should be unchanged
	if string(out) != string(payload) {
		t.Fatalf("body should be unchanged when no betas field exists, got %s", string(out))
	}
}
// TestStripUnsupportedBetas_MetadataBetas verifies that stripping also
// applies to the nested metadata.betas array and preserves its other entries.
func TestStripUnsupportedBetas_MetadataBetas(t *testing.T) {
	t.Parallel()
	body := []byte(`{"model":"claude-opus-4.6","metadata":{"betas":["context-1m-2025-08-07","other-beta"]},"messages":[]}`)
	result := stripUnsupportedBetas(body)
	betas := gjson.GetBytes(result, "metadata.betas")
	if !betas.Exists() {
		t.Fatal("metadata.betas field should still exist after stripping")
	}
	arr := betas.Array()
	for _, item := range arr {
		if item.String() == "context-1m-2025-08-07" {
			t.Fatal("context-1m-2025-08-07 should have been stripped from metadata.betas")
		}
	}
	// Guard len() before indexing: an over-aggressive strip that empties the
	// array should fail the test, not panic it on arr[0].
	if len(arr) == 0 || arr[0].String() != "other-beta" {
		t.Fatal("other betas in metadata.betas should be preserved")
	}
}
// TestStripUnsupportedBetas_AllBetasStripped verifies that the betas field is
// removed entirely when every entry was unsupported.
func TestStripUnsupportedBetas_AllBetasStripped(t *testing.T) {
	t.Parallel()
	payload := []byte(`{"model":"claude-opus-4.6","betas":["context-1m-2025-08-07"],"messages":[]}`)
	out := stripUnsupportedBetas(payload)
	if gjson.GetBytes(out, "betas").Exists() {
		t.Fatal("betas field should be deleted when all betas are stripped")
	}
}
// TestCopilotModelEntry_Limits exercises CopilotModelEntry.Limits() across
// the capability shapes seen from the Copilot /models API: absent or invalid
// "limits", all-zero values (treated as no limits), real individual/business
// account limits (where max_prompt_tokens is below the context window), and
// a partially-populated limits object.
func TestCopilotModelEntry_Limits(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name         string
		capabilities map[string]any
		wantNil      bool
		wantPrompt   int
		wantOutput   int
		wantContext  int
	}{
		{
			// No capabilities at all → Limits() must return nil.
			name:         "nil capabilities",
			capabilities: nil,
			wantNil:      true,
		},
		{
			// Capabilities present but no "limits" key → nil.
			name:         "no limits key",
			capabilities: map[string]any{"family": "claude-opus-4.6"},
			wantNil:      true,
		},
		{
			// "limits" of the wrong type must be ignored, not crash.
			name:         "limits is not a map",
			capabilities: map[string]any{"limits": "invalid"},
			wantNil:      true,
		},
		{
			// All-zero limits are meaningless and treated as absent.
			name: "all zero values",
			capabilities: map[string]any{
				"limits": map[string]any{
					"max_context_window_tokens": float64(0),
					"max_prompt_tokens":         float64(0),
					"max_output_tokens":         float64(0),
				},
			},
			wantNil: true,
		},
		{
			name: "individual account limits (128K prompt)",
			capabilities: map[string]any{
				"limits": map[string]any{
					"max_context_window_tokens": float64(144000),
					"max_prompt_tokens":         float64(128000),
					"max_output_tokens":         float64(64000),
				},
			},
			wantNil:     false,
			wantPrompt:  128000,
			wantOutput:  64000,
			wantContext: 144000,
		},
		{
			name: "business account limits (168K prompt)",
			capabilities: map[string]any{
				"limits": map[string]any{
					"max_context_window_tokens": float64(200000),
					"max_prompt_tokens":         float64(168000),
					"max_output_tokens":         float64(32000),
				},
			},
			wantNil:     false,
			wantPrompt:  168000,
			wantOutput:  32000,
			wantContext: 200000,
		},
		{
			// Only one field set → the others default to zero.
			name: "partial limits (only prompt)",
			capabilities: map[string]any{
				"limits": map[string]any{
					"max_prompt_tokens": float64(128000),
				},
			},
			wantNil:    false,
			wantPrompt: 128000,
			wantOutput: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			entry := copilotauth.CopilotModelEntry{
				ID:           "claude-opus-4.6",
				Capabilities: tt.capabilities,
			}
			limits := entry.Limits()
			if tt.wantNil {
				if limits != nil {
					t.Fatalf("expected nil limits, got %+v", limits)
				}
				return
			}
			if limits == nil {
				t.Fatal("expected non-nil limits, got nil")
			}
			if limits.MaxPromptTokens != tt.wantPrompt {
				t.Errorf("MaxPromptTokens = %d, want %d", limits.MaxPromptTokens, tt.wantPrompt)
			}
			if limits.MaxOutputTokens != tt.wantOutput {
				t.Errorf("MaxOutputTokens = %d, want %d", limits.MaxOutputTokens, tt.wantOutput)
			}
			// Context window is only asserted when the case specifies one.
			if tt.wantContext > 0 && limits.MaxContextWindowTokens != tt.wantContext {
				t.Errorf("MaxContextWindowTokens = %d, want %d", limits.MaxContextWindowTokens, tt.wantContext)
			}
		})
	}
}

View File

@@ -30,12 +30,20 @@ const (
gitLabChatEndpoint = "/api/v4/chat/completions" gitLabChatEndpoint = "/api/v4/chat/completions"
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions" gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming" gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
gitLabContext1MBeta = "context-1m-2025-08-07"
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
) )
type GitLabExecutor struct { type GitLabExecutor struct {
cfg *config.Config cfg *config.Config
} }
// gitLabCatalogModel describes one entry of the static GitLab Duo model
// catalog exposed by this executor.
type gitLabCatalogModel struct {
	ID          string // model identifier exposed to clients
	DisplayName string // human-readable display name
	Provider    string // upstream provider label (e.g. "openai", "anthropic")
}
type gitLabPrompt struct { type gitLabPrompt struct {
Instruction string Instruction string
FileName string FileName string
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
Finished bool Finished bool
} }
// gitLabAgenticCatalog is the static list of GitLab Duo agentic chat models
// advertised by GitLabModelsFromAuth, each tagged with its upstream provider.
var gitLabAgenticCatalog = []gitLabCatalogModel{
	{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
	{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
	{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
	{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
	{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
	{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
	{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
	{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
	{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
	{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
}
// gitLabModelAliases maps requested model IDs (lowercase) to the upstream
// model actually sent to GitLab; consulted by gitLabResolvedModel.
var gitLabModelAliases = map[string]string{
	"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
}
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor { func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
return &GitLabExecutor{cfg: cfg} return &GitLabExecutor{cfg: cfg}
} }
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
auth *cliproxyauth.Auth, auth *cliproxyauth.Auth,
req cliproxyexecutor.Request, req cliproxyexecutor.Request,
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) { ) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok { if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
nativeReq := req nativeReq := req
nativeReq.Model = gitLabResolvedModel(auth, req.Model) nativeReq.Model = gitLabResolvedModel(auth, req.Model)
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
} }
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok { if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
nativeReq := req nativeReq := req
nativeReq.Model = gitLabResolvedModel(auth, req.Model) nativeReq.Model = gitLabResolvedModel(auth, req.Model)
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
} }
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) { func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok { if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
return NewClaudeExecutor(e.cfg), nativeAuth return NewClaudeExecutor(e.cfg), nativeAuth
} }
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok { if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
return NewCodexExecutor(e.cfg), nativeAuth return NewCodexExecutor(e.cfg), nativeAuth
} }
return nil, nil return nil, nil
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
if auth != nil { if auth != nil {
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes) util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
} }
for key, value := range gitLabGatewayHeaders(auth) { for key, value := range gitLabGatewayHeaders(auth, "") {
if key == "" || value == "" { if key == "" || value == "" {
continue continue
} }
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
} }
} }
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string { func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
if auth == nil || auth.Metadata == nil {
return nil
}
raw, ok := auth.Metadata["duo_gateway_headers"]
if !ok {
return nil
}
out := make(map[string]string) out := make(map[string]string)
switch typed := raw.(type) { if auth != nil && auth.Metadata != nil {
case map[string]string: raw, ok := auth.Metadata["duo_gateway_headers"]
for key, value := range typed { if ok {
key = strings.TrimSpace(key) switch typed := raw.(type) {
value = strings.TrimSpace(value) case map[string]string:
if key != "" && value != "" { for key, value := range typed {
out[key] = value key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key != "" && value != "" {
out[key] = value
}
}
case map[string]any:
for key, value := range typed {
key = strings.TrimSpace(key)
if key == "" {
continue
}
strValue := strings.TrimSpace(fmt.Sprint(value))
if strValue != "" {
out[key] = strValue
}
}
} }
} }
case map[string]any: }
for key, value := range typed { if _, ok := out["User-Agent"]; !ok {
key = strings.TrimSpace(key) out["User-Agent"] = gitLabNativeUserAgent
if key == "" { }
continue if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
} if _, ok := out["anthropic-beta"]; !ok {
strValue := strings.TrimSpace(fmt.Sprint(value)) out["anthropic-beta"] = gitLabContext1MBeta
if strValue != "" {
out[key] = strValue
}
} }
} }
if len(out) == 0 { if len(out) == 0 {
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
return promptTokens, int64(completionCount) return promptTokens, int64(completionCount)
} }
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) { func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
if !gitLabUsesAnthropicGateway(auth) { if !gitLabUsesAnthropicGateway(auth, requestedModel) {
return nil, false return nil, false
} }
baseURL := gitLabAnthropicGatewayBaseURL(auth) baseURL := gitLabAnthropicGatewayBaseURL(auth)
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
} }
nativeAuth.Attributes["api_key"] = token nativeAuth.Attributes["api_key"] = token
nativeAuth.Attributes["base_url"] = baseURL nativeAuth.Attributes["base_url"] = baseURL
for key, value := range gitLabGatewayHeaders(auth) { nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
if key == "" || value == "" { if key == "" || value == "" {
continue continue
} }
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
return nativeAuth, true return nativeAuth, true
} }
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) { func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
if !gitLabUsesOpenAIGateway(auth) { if !gitLabUsesOpenAIGateway(auth, requestedModel) {
return nil, false return nil, false
} }
baseURL := gitLabOpenAIGatewayBaseURL(auth) baseURL := gitLabOpenAIGatewayBaseURL(auth)
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
} }
nativeAuth.Attributes["api_key"] = token nativeAuth.Attributes["api_key"] = token
nativeAuth.Attributes["base_url"] = baseURL nativeAuth.Attributes["base_url"] = baseURL
for key, value := range gitLabGatewayHeaders(auth) { for key, value := range gitLabGatewayHeaders(auth, "openai") {
if key == "" || value == "" { if key == "" || value == "" {
continue continue
} }
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
return nativeAuth, true return nativeAuth, true
} }
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool { func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
if auth == nil || auth.Metadata == nil { if auth == nil || auth.Metadata == nil {
return false return false
} }
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider")) provider := gitLabGatewayProvider(auth, requestedModel)
if provider == "" {
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
provider = inferGitLabProviderFromModel(modelName)
}
return provider == "anthropic" && return provider == "anthropic" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" && gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != "" gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
} }
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool { func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
if auth == nil || auth.Metadata == nil { if auth == nil || auth.Metadata == nil {
return false return false
} }
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider")) provider := gitLabGatewayProvider(auth, requestedModel)
if provider == "" {
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
provider = inferGitLabProviderFromModel(modelName)
}
return provider == "openai" && return provider == "openai" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" && gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != "" gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
} }
// gitLabGatewayProvider resolves the upstream provider for a request. The
// resolved model name takes priority; if it does not identify a provider,
// the auth metadata's model_provider (lowercased) is used, then a provider
// inferred from the metadata model_name. Returns "" when nothing matches.
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
	resolved := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
	if fromModel := inferGitLabProviderFromModel(resolved); fromModel != "" {
		return fromModel
	}
	if auth == nil || auth.Metadata == nil {
		return ""
	}
	if fromMeta := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider")); fromMeta != "" {
		return fromMeta
	}
	return inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
}
func inferGitLabProviderFromModel(model string) string { func inferGitLabProviderFromModel(model string) string {
model = strings.ToLower(strings.TrimSpace(model)) model = strings.ToLower(strings.TrimSpace(model))
switch { switch {
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string { func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName) requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") { if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
return mapped
}
return requested return requested
} }
if auth != nil && auth.Metadata != nil { if auth != nil && auth.Metadata != nil {
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
} }
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo { func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
models := make([]*registry.ModelInfo, 0, 4) models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
seen := make(map[string]struct{}, 4) seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
addModel := func(id, displayName, provider string) { addModel := func(id, displayName, provider string) {
id = strings.TrimSpace(id) id = strings.TrimSpace(id)
if id == "" { if id == "" {
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
} }
addModel("gitlab-duo", "GitLab Duo", "gitlab") addModel("gitlab-duo", "GitLab Duo", "gitlab")
for _, model := range gitLabAgenticCatalog {
addModel(model.ID, model.DisplayName, model.Provider)
}
for alias, upstream := range gitLabModelAliases {
target := strings.TrimSpace(upstream)
displayName := "GitLab Duo Alias"
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
if provider != "" {
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
}
addModel(alias, displayName, provider)
}
if auth == nil { if auth == nil {
return models return models
} }

View File

@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
} }
} }
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotBetaHeader = r.Header.Get("anthropic-beta")
gotUserAgent = r.Header.Get("User-Agent")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "duo-chat-gpt-5-codex",
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotBetaHeader != gitLabContext1MBeta {
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
}
if gotUserAgent != gitLabNativeUserAgent {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
}
if gotModel != "duo-chat-gpt-5-codex" {
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
}
}
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) { func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
ID: "gitlab-auth.json", ID: "gitlab-auth.json",
Provider: "gitlab", Provider: "gitlab",
Metadata: map[string]any{ Metadata: map[string]any{
"base_url": srv.URL, "base_url": srv.URL,
"access_token": "oauth-access", "access_token": "oauth-access",
"refresh_token": "oauth-refresh", "refresh_token": "oauth-refresh",
"oauth_client_id": "client-id", "oauth_client_id": "client-id",
"oauth_client_secret": "client-secret", "auth_method": "oauth",
"auth_method": "oauth", "oauth_expires_at": "2000-01-01T00:00:00Z",
"oauth_expires_at": "2000-01-01T00:00:00Z",
}, },
} }
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
} }
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) { func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
var gotPath string var gotPath, gotBetaHeader, gotUserAgent string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path gotPath = r.URL.Path
gotBetaHeader = r.Header.Get("Anthropic-Beta")
gotUserAgent = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: message_start\n")) _, _ = w.Write([]byte("event: message_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n")) _, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
if gotPath != "/v1/proxy/anthropic/v1/messages" { if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages") t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
} }
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
}
if gotUserAgent != gitLabNativeUserAgent {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
}
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") { if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n")) t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
} }

View File

@@ -1,11 +1,11 @@
package executor package helps
import ( import (
"sync" "sync"
"time" "time"
) )
type codexCache struct { type CodexCache struct {
ID string ID string
Expire time.Time Expire time.Time
} }
@@ -13,7 +13,7 @@ type codexCache struct {
// codexCacheMap stores prompt cache IDs keyed by model+user_id. // codexCacheMap stores prompt cache IDs keyed by model+user_id.
// Protected by codexCacheMu. Entries expire after 1 hour. // Protected by codexCacheMu. Entries expire after 1 hour.
var ( var (
codexCacheMap = make(map[string]codexCache) codexCacheMap = make(map[string]CodexCache)
codexCacheMu sync.RWMutex codexCacheMu sync.RWMutex
) )
@@ -50,20 +50,20 @@ func purgeExpiredCodexCache() {
} }
} }
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired. // GetCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func getCodexCache(key string) (codexCache, bool) { func GetCodexCache(key string) (CodexCache, bool) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup) codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.RLock() codexCacheMu.RLock()
cache, ok := codexCacheMap[key] cache, ok := codexCacheMap[key]
codexCacheMu.RUnlock() codexCacheMu.RUnlock()
if !ok || cache.Expire.Before(time.Now()) { if !ok || cache.Expire.Before(time.Now()) {
return codexCache{}, false return CodexCache{}, false
} }
return cache, true return cache, true
} }
// setCodexCache stores a cache entry. // SetCodexCache stores a cache entry.
func setCodexCache(key string, cache codexCache) { func SetCodexCache(key string, cache CodexCache) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup) codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.Lock() codexCacheMu.Lock()
codexCacheMap[key] = cache codexCacheMap[key] = cache

View File

@@ -0,0 +1,38 @@
package helps
import "github.com/tidwall/gjson"
var defaultClaudeBuiltinToolNames = []string{
"web_search",
"code_execution",
"text_editor",
"computer",
}
func newClaudeBuiltinToolRegistry() map[string]bool {
registry := make(map[string]bool, len(defaultClaudeBuiltinToolNames))
for _, name := range defaultClaudeBuiltinToolNames {
registry[name] = true
}
return registry
}
func AugmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool {
if registry == nil {
registry = newClaudeBuiltinToolRegistry()
}
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return registry
}
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "" {
return true
}
if name := tool.Get("name").String(); name != "" {
registry[name] = true
}
return true
})
return registry
}

View File

@@ -0,0 +1,32 @@
package helps
import "testing"
func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) {
registry := AugmentClaudeBuiltinToolRegistry(nil, nil)
for _, name := range defaultClaudeBuiltinToolNames {
if !registry[name] {
t.Fatalf("default builtin %q missing from fallback registry", name)
}
}
}
func TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody(t *testing.T) {
registry := AugmentClaudeBuiltinToolRegistry([]byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"type": "custom_builtin_20250401", "name": "special_builtin"},
{"name": "Read"}
]
}`), nil)
if !registry["web_search"] {
t.Fatal("expected default typed builtin web_search in registry")
}
if !registry["special_builtin"] {
t.Fatal("expected typed builtin from body to be added to registry")
}
if registry["Read"] {
t.Fatal("expected untyped custom tool to stay out of builtin registry")
}
}

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"crypto/sha256" "crypto/sha256"
@@ -32,7 +32,7 @@ var (
claudeDeviceProfileCacheMu sync.RWMutex claudeDeviceProfileCacheMu sync.RWMutex
claudeDeviceProfileCacheCleanupOnce sync.Once claudeDeviceProfileCacheCleanupOnce sync.Once
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile) ClaudeDeviceProfileBeforeCandidateStore func(ClaudeDeviceProfile)
) )
type claudeCLIVersion struct { type claudeCLIVersion struct {
@@ -63,29 +63,43 @@ func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
} }
} }
type claudeDeviceProfile struct { type ClaudeDeviceProfile struct {
UserAgent string UserAgent string
PackageVersion string PackageVersion string
RuntimeVersion string RuntimeVersion string
OS string OS string
Arch string Arch string
Version claudeCLIVersion version claudeCLIVersion
HasVersion bool hasVersion bool
} }
type claudeDeviceProfileCacheEntry struct { type claudeDeviceProfileCacheEntry struct {
profile claudeDeviceProfile profile ClaudeDeviceProfile
expire time.Time expire time.Time
} }
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool { func ClaudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil { if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
return false return false
} }
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
} }
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile { func ResetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
}
func MapStainlessOS() string {
return mapStainlessOS()
}
func MapStainlessArch() string {
return mapStainlessArch()
}
func defaultClaudeDeviceProfile(cfg *config.Config) ClaudeDeviceProfile {
hdrDefault := func(cfgVal, fallback string) string { hdrDefault := func(cfgVal, fallback string) string {
if strings.TrimSpace(cfgVal) != "" { if strings.TrimSpace(cfgVal) != "" {
return strings.TrimSpace(cfgVal) return strings.TrimSpace(cfgVal)
@@ -98,7 +112,7 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
hd = cfg.ClaudeHeaderDefaults hd = cfg.ClaudeHeaderDefaults
} }
profile := claudeDeviceProfile{ profile := ClaudeDeviceProfile{
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent), UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion), PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion), RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
@@ -106,8 +120,8 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch), Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
} }
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok { if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
profile.Version = version profile.version = version
profile.HasVersion = true profile.hasVersion = true
} }
return profile return profile
} }
@@ -162,17 +176,17 @@ func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
} }
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool { func shouldUpgradeClaudeDeviceProfile(candidate, current ClaudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.HasVersion { if candidate.UserAgent == "" || !candidate.hasVersion {
return false return false
} }
if current.UserAgent == "" || !current.HasVersion { if current.UserAgent == "" || !current.hasVersion {
return true return true
} }
return candidate.Version.Compare(current.Version) > 0 return candidate.version.Compare(current.version) > 0
} }
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile { func pinClaudeDeviceProfilePlatform(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
profile.OS = baseline.OS profile.OS = baseline.OS
profile.Arch = baseline.Arch profile.Arch = baseline.Arch
return profile return profile
@@ -180,38 +194,38 @@ func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claud
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current // normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
// baseline platform and enforces the baseline software fingerprint as a floor. // baseline platform and enforces the baseline software fingerprint as a floor.
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile { func normalizeClaudeDeviceProfile(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
profile = pinClaudeDeviceProfilePlatform(profile, baseline) profile = pinClaudeDeviceProfilePlatform(profile, baseline)
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) { if profile.UserAgent == "" || !profile.hasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
profile.UserAgent = baseline.UserAgent profile.UserAgent = baseline.UserAgent
profile.PackageVersion = baseline.PackageVersion profile.PackageVersion = baseline.PackageVersion
profile.RuntimeVersion = baseline.RuntimeVersion profile.RuntimeVersion = baseline.RuntimeVersion
profile.Version = baseline.Version profile.version = baseline.version
profile.HasVersion = baseline.HasVersion profile.hasVersion = baseline.hasVersion
} }
return profile return profile
} }
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) { func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (ClaudeDeviceProfile, bool) {
if headers == nil { if headers == nil {
return claudeDeviceProfile{}, false return ClaudeDeviceProfile{}, false
} }
userAgent := strings.TrimSpace(headers.Get("User-Agent")) userAgent := strings.TrimSpace(headers.Get("User-Agent"))
version, ok := parseClaudeCLIVersion(userAgent) version, ok := parseClaudeCLIVersion(userAgent)
if !ok { if !ok {
return claudeDeviceProfile{}, false return ClaudeDeviceProfile{}, false
} }
baseline := defaultClaudeDeviceProfile(cfg) baseline := defaultClaudeDeviceProfile(cfg)
profile := claudeDeviceProfile{ profile := ClaudeDeviceProfile{
UserAgent: userAgent, UserAgent: userAgent,
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion), PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion), RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS), OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch), Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
Version: version, version: version,
HasVersion: true, hasVersion: true,
} }
return profile, true return profile, true
} }
@@ -263,7 +277,7 @@ func purgeExpiredClaudeDeviceProfiles() {
claudeDeviceProfileCacheMu.Unlock() claudeDeviceProfileCacheMu.Unlock()
} }
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile { func ResolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) ClaudeDeviceProfile {
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup) claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey) cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
@@ -283,8 +297,8 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
claudeDeviceProfileCacheMu.RUnlock() claudeDeviceProfileCacheMu.RUnlock()
if hasCandidate { if hasCandidate {
if claudeDeviceProfileBeforeCandidateStore != nil { if ClaudeDeviceProfileBeforeCandidateStore != nil {
claudeDeviceProfileBeforeCandidateStore(candidate) ClaudeDeviceProfileBeforeCandidateStore(candidate)
} }
claudeDeviceProfileCacheMu.Lock() claudeDeviceProfileCacheMu.Lock()
@@ -324,7 +338,7 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
return baseline return baseline
} }
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) { func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfile) {
if r == nil { if r == nil {
return return
} }
@@ -344,7 +358,17 @@ func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfil
r.Header.Set("X-Stainless-Arch", profile.Arch) r.Header.Set("X-Stainless-Arch", profile.Arch)
} }
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) { // DefaultClaudeVersion returns the version string (e.g. "2.1.63") from the
// current baseline device profile. It extracts the version from the User-Agent.
func DefaultClaudeVersion(cfg *config.Config) string {
profile := defaultClaudeDeviceProfile(cfg)
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
return strconv.Itoa(version.major) + "." + strconv.Itoa(version.minor) + "." + strconv.Itoa(version.patch)
}
return "2.1.63"
}
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
if r == nil { if r == nil {
return return
} }

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"regexp" "regexp"
@@ -18,9 +18,9 @@ type SensitiveWordMatcher struct {
regex *regexp.Regexp regex *regexp.Regexp
} }
// buildSensitiveWordMatcher compiles a regex from the word list. // BuildSensitiveWordMatcher compiles a regex from the word list.
// Words are sorted by length (longest first) for proper matching. // Words are sorted by length (longest first) for proper matching.
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher { func BuildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
if len(words) == 0 { if len(words) == 0 {
return nil return nil
} }
@@ -81,9 +81,9 @@ func (m *SensitiveWordMatcher) obfuscateText(text string) string {
return m.regex.ReplaceAllStringFunc(text, obfuscateWord) return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
} }
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words // ObfuscateSensitiveWords processes the payload and obfuscates sensitive words
// in system blocks and message content. // in system blocks and message content.
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte { func ObfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
if matcher == nil || matcher.regex == nil { if matcher == nil || matcher.regex == nil {
return payload return payload
} }

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"crypto/rand" "crypto/rand"
@@ -28,9 +28,17 @@ func isValidUserID(userID string) bool {
return userIDPattern.MatchString(userID) return userIDPattern.MatchString(userID)
} }
// shouldCloak determines if request should be cloaked based on config and client User-Agent. func GenerateFakeUserID() string {
return generateFakeUserID()
}
func IsValidUserID(userID string) bool {
return isValidUserID(userID)
}
// ShouldCloak determines if request should be cloaked based on config and client User-Agent.
// Returns true if cloaking should be applied. // Returns true if cloaking should be applied.
func shouldCloak(cloakMode string, userAgent string) bool { func ShouldCloak(cloakMode string, userAgent string) bool {
switch strings.ToLower(cloakMode) { switch strings.ToLower(cloakMode) {
case "always": case "always":
return true return true

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"bytes" "bytes"
@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"html" "html"
"net/http" "net/http"
"net/url"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -19,13 +20,14 @@ import (
) )
const ( const (
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS" apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
apiRequestKey = "API_REQUEST" apiRequestKey = "API_REQUEST"
apiResponseKey = "API_RESPONSE" apiResponseKey = "API_RESPONSE"
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
) )
// upstreamRequestLog captures the outbound upstream request details for logging. // UpstreamRequestLog captures the outbound upstream request details for logging.
type upstreamRequestLog struct { type UpstreamRequestLog struct {
URL string URL string
Method string Method string
Headers http.Header Headers http.Header
@@ -46,11 +48,12 @@ type upstreamAttempt struct {
headersWritten bool headersWritten bool
bodyStarted bool bodyStarted bool
bodyHasContent bool bodyHasContent bool
prevWasSSEEvent bool
errorWritten bool errorWritten bool
} }
// recordAPIRequest stores the upstream request metadata in Gin context for request logging. // RecordAPIRequest stores the upstream request metadata in Gin context for request logging.
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) { func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog { if cfg == nil || !cfg.RequestLog {
return return
} }
@@ -96,8 +99,8 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
updateAggregatedRequest(ginCtx, attempts) updateAggregatedRequest(ginCtx, attempts)
} }
// recordAPIResponseMetadata captures upstream response status/header information for the latest attempt. // RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) { func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog { if cfg == nil || !cfg.RequestLog {
return return
} }
@@ -122,8 +125,8 @@ func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status i
updateAggregatedResponse(ginCtx, attempts) updateAggregatedResponse(ginCtx, attempts)
} }
// recordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available. // RecordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) { func RecordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
if cfg == nil || !cfg.RequestLog || err == nil { if cfg == nil || !cfg.RequestLog || err == nil {
return return
} }
@@ -147,8 +150,8 @@ func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error)
updateAggregatedResponse(ginCtx, attempts) updateAggregatedResponse(ginCtx, attempts)
} }
// appendAPIResponseChunk appends an upstream response chunk to Gin context for request logging. // AppendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) { func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
if cfg == nil || !cfg.RequestLog { if cfg == nil || !cfg.RequestLog {
return return
} }
@@ -173,15 +176,157 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
attempt.response.WriteString("Body:\n") attempt.response.WriteString("Body:\n")
attempt.bodyStarted = true attempt.bodyStarted = true
} }
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
if attempt.bodyHasContent { if attempt.bodyHasContent {
attempt.response.WriteString("\n\n") separator := "\n\n"
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
separator = "\n"
}
attempt.response.WriteString(separator)
} }
attempt.response.WriteString(string(data)) attempt.response.WriteString(string(data))
attempt.bodyHasContent = true attempt.bodyHasContent = true
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
updateAggregatedResponse(ginCtx, attempts) updateAggregatedResponse(ginCtx, attempts)
} }
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.request\n")
if info.URL != "" {
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
}
if auth := formatAuthInfo(info); auth != "" {
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, info.Headers)
builder.WriteString("\nBody:\n")
if len(info.Body) > 0 {
builder.Write(info.Body)
} else {
builder.WriteString("<empty>")
}
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.handshake\n")
if status > 0 {
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, headers)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
RecordAPIRequest(ctx, cfg, info)
RecordAPIResponseMetadata(ctx, cfg, status, headers)
AppendAPIResponseChunk(ctx, cfg, body)
}
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
func WebsocketUpgradeRequestURL(rawURL string) string {
trimmedURL := strings.TrimSpace(rawURL)
if trimmedURL == "" {
return ""
}
parsed, err := url.Parse(trimmedURL)
if err != nil {
return trimmedURL
}
switch strings.ToLower(parsed.Scheme) {
case "ws":
parsed.Scheme = "http"
case "wss":
parsed.Scheme = "https"
}
return parsed.String()
}
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
data := bytes.TrimSpace(payload)
if len(data) == 0 {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.response\n")
builder.Write(data)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
if cfg == nil || !cfg.RequestLog || err == nil {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.error\n")
if trimmed := strings.TrimSpace(stage); trimmed != "" {
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
}
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
func ginContextFrom(ctx context.Context) *gin.Context { func ginContextFrom(ctx context.Context) *gin.Context {
ginCtx, _ := ctx.Value("gin").(*gin.Context) ginCtx, _ := ctx.Value("gin").(*gin.Context)
return ginCtx return ginCtx
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
ginCtx.Set(apiResponseKey, []byte(builder.String())) ginCtx.Set(apiResponseKey, []byte(builder.String()))
} }
// appendAPIWebsocketTimeline appends a trimmed chunk to the websocket timeline
// stored in the Gin context, separating successive entries with a blank line.
// The first entry is stored as a defensive copy of the chunk.
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
	if ginCtx == nil {
		return
	}
	entry := bytes.TrimSpace(chunk)
	if len(entry) == 0 {
		return
	}
	if prev, ok := ginCtx.Get(apiWebsocketTimelineKey); ok {
		if prevBytes, isBytes := prev.([]byte); isBytes && len(prevBytes) > 0 {
			var buf bytes.Buffer
			buf.Grow(len(prevBytes) + len(entry) + 2)
			buf.Write(prevBytes)
			// Ensure the previous entry ends with a newline, then add the
			// blank-line separator before the new entry.
			if !bytes.HasSuffix(prevBytes, []byte("\n")) {
				buf.WriteByte('\n')
			}
			buf.WriteByte('\n')
			buf.Write(entry)
			ginCtx.Set(apiWebsocketTimelineKey, buf.Bytes())
			return
		}
	}
	ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(entry))
}
// markAPIResponseTimestamp records the first-response timestamp in the Gin
// context exactly once; subsequent calls leave the original value untouched.
func markAPIResponseTimestamp(ginCtx *gin.Context) {
	if ginCtx == nil {
		return
	}
	_, alreadySet := ginCtx.Get("API_RESPONSE_TIMESTAMP")
	if alreadySet {
		return
	}
	ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
}
func writeHeaders(builder *strings.Builder, headers http.Header) { func writeHeaders(builder *strings.Builder, headers http.Header) {
if builder == nil { if builder == nil {
return return
@@ -285,7 +464,7 @@ func writeHeaders(builder *strings.Builder, headers http.Header) {
} }
} }
func formatAuthInfo(info upstreamRequestLog) string { func formatAuthInfo(info UpstreamRequestLog) string {
var parts []string var parts []string
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" { if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
parts = append(parts, fmt.Sprintf("provider=%s", trimmed)) parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
@@ -321,7 +500,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
return strings.Join(parts, ", ") return strings.Join(parts, ", ")
} }
func summarizeErrorBody(contentType string, body []byte) string { func SummarizeErrorBody(contentType string, body []byte) string {
isHTML := strings.Contains(strings.ToLower(contentType), "text/html") isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
if !isHTML { if !isHTML {
trimmed := bytes.TrimSpace(bytes.ToLower(body)) trimmed := bytes.TrimSpace(bytes.ToLower(body))
@@ -379,7 +558,7 @@ func extractJSONErrorMessage(body []byte) string {
// logWithRequestID returns a logrus Entry with request_id field populated from context. // logWithRequestID returns a logrus Entry with request_id field populated from context.
// If no request ID is found in context, it returns the standard logger. // If no request ID is found in context, it returns the standard logger.
func logWithRequestID(ctx context.Context) *log.Entry { func LogWithRequestID(ctx context.Context) *log.Entry {
if ctx == nil { if ctx == nil {
return log.NewEntry(log.StandardLogger()) return log.NewEntry(log.StandardLogger())
} }

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"encoding/json" "encoding/json"
@@ -11,12 +11,12 @@ import (
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter // ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI) // paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied. Defaults are checked // and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided. requestedModel carries the client-visible // against the original payload when provided. requestedModel carries the client-visible
// model name before alias resolution so payload rules can target aliases precisely. // model name before alias resolution so payload rules can target aliases precisely.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte { func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
if cfg == nil || len(payload) == 0 { if cfg == nil || len(payload) == 0 {
return payload return payload
} }
@@ -244,7 +244,7 @@ func payloadRawValue(value any) ([]byte, bool) {
} }
} }
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string { func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
fallback = strings.TrimSpace(fallback) fallback = strings.TrimSpace(fallback)
if len(opts.Metadata) == 0 { if len(opts.Metadata) == 0 {
return fallback return fallback

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"context" "context"
@@ -19,7 +19,7 @@ var (
httpClientCacheMutex sync.RWMutex httpClientCacheMutex sync.RWMutex
) )
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority) // 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured // 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured // 3. Use RoundTripper from context if neither are configured
@@ -34,7 +34,7 @@ var (
// //
// Returns: // Returns:
// - *http.Client: An HTTP client with configured proxy or transport // - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
// Priority 1: Use auth.ProxyURL if configured // Priority 1: Use auth.ProxyURL if configured
var proxyURL string var proxyURL string
if auth != nil { if auth != nil {
@@ -46,23 +46,18 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL) proxyURL = strings.TrimSpace(cfg.ProxyURL)
} }
// Build cache key from proxy URL (empty string for no proxy) // If we have a proxy URL configured, try cache first to reuse TCP/TLS connections.
cacheKey := proxyURL if proxyURL != "" {
httpClientCacheMutex.RLock()
// Check cache first if cachedClient, ok := httpClientCache[proxyURL]; ok {
httpClientCacheMutex.RLock() httpClientCacheMutex.RUnlock()
if cachedClient, ok := httpClientCache[cacheKey]; ok { if timeout > 0 {
httpClientCacheMutex.RUnlock() return &http.Client{Transport: cachedClient.Transport, Timeout: timeout}
// Return a wrapper with the requested timeout but shared transport
if timeout > 0 {
return &http.Client{
Transport: cachedClient.Transport,
Timeout: timeout,
} }
return cachedClient
} }
return cachedClient httpClientCacheMutex.RUnlock()
} }
httpClientCacheMutex.RUnlock()
// Create new client // Create new client
httpClient := &http.Client{} httpClient := &http.Client{}
@@ -77,7 +72,7 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = transport httpClient.Transport = transport
// Cache the client // Cache the client
httpClientCacheMutex.Lock() httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient httpClientCache[proxyURL] = httpClient
httpClientCacheMutex.Unlock() httpClientCacheMutex.Unlock()
return httpClient return httpClient
} }
@@ -90,13 +85,6 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt httpClient.Transport = rt
} }
// Cache the client for no-proxy case
if proxyURL == "" {
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
}
return httpClient return httpClient
} }

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"context" "context"
@@ -13,7 +13,7 @@ import (
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) { func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel() t.Parallel()
client := newProxyAwareHTTPClient( client := NewProxyAwareHTTPClient(
context.Background(), context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}}, &config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"}, &cliproxyauth.Auth{ProxyURL: "direct"},

View File

@@ -0,0 +1,92 @@
package helps
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
"github.com/google/uuid"
)
// sessionIDCacheEntry pairs a cached session UUID with its expiry time.
type sessionIDCacheEntry struct {
	value  string    // cached session UUID
	expire time.Time // entry is stale once this instant has passed
}

// Package-level session-ID cache, keyed by the SHA-256 hash of the API key.
var (
	sessionIDCache            = make(map[string]sessionIDCacheEntry)
	sessionIDCacheMu          sync.RWMutex // guards sessionIDCache
	sessionIDCacheCleanupOnce sync.Once    // ensures the cleanup goroutine starts once
)

const (
	sessionIDTTL                = time.Hour        // lifetime of a cached session ID; refreshed on each access
	sessionIDCacheCleanupPeriod = 15 * time.Minute // interval between expired-entry sweeps
)
// startSessionIDCacheCleanup launches a background goroutine that periodically
// removes expired session IDs from the cache. The goroutine (and its ticker)
// intentionally lives for the remainder of the process lifetime.
func startSessionIDCacheCleanup() {
	go func() {
		ticker := time.NewTicker(sessionIDCacheCleanupPeriod)
		defer ticker.Stop()
		for {
			<-ticker.C
			purgeExpiredSessionIDs()
		}
	}()
}
// purgeExpiredSessionIDs deletes every cache entry whose expiry time is at or
// before the moment the sweep starts.
func purgeExpiredSessionIDs() {
	cutoff := time.Now()
	sessionIDCacheMu.Lock()
	defer sessionIDCacheMu.Unlock()
	for key, entry := range sessionIDCache {
		if !entry.expire.After(cutoff) {
			delete(sessionIDCache, key)
		}
	}
}
// sessionIDCacheKey derives a stable, non-reversible cache key from an API key
// by hex-encoding its SHA-256 digest, so raw keys never sit in the map.
func sessionIDCacheKey(apiKey string) string {
	hasher := sha256.New()
	hasher.Write([]byte(apiKey))
	return hex.EncodeToString(hasher.Sum(nil))
}
// CachedSessionID returns a stable session UUID per apiKey, refreshing the TTL on each access.
// An empty apiKey yields a fresh, uncached UUID on every call.
func CachedSessionID(apiKey string) string {
	if apiKey == "" {
		return uuid.New().String()
	}
	// Lazily start the background sweeper the first time the cache is used.
	sessionIDCacheCleanupOnce.Do(startSessionIDCacheCleanup)
	key := sessionIDCacheKey(apiKey)
	now := time.Now()
	// Fast path: validate the entry under the read lock only.
	sessionIDCacheMu.RLock()
	entry, ok := sessionIDCache[key]
	valid := ok && entry.value != "" && entry.expire.After(now)
	sessionIDCacheMu.RUnlock()
	if valid {
		// Re-check under the write lock before extending the TTL: the entry
		// may have been replaced or expired between lock acquisitions.
		sessionIDCacheMu.Lock()
		entry = sessionIDCache[key]
		if entry.value != "" && entry.expire.After(now) {
			entry.expire = now.Add(sessionIDTTL)
			sessionIDCache[key] = entry
			sessionIDCacheMu.Unlock()
			return entry.value
		}
		sessionIDCacheMu.Unlock()
		// Entry went stale between the checks; fall through and mint a new ID.
	}
	newID := uuid.New().String()
	sessionIDCacheMu.Lock()
	// Another goroutine may have stored a fresh ID first; keep the winner and
	// only install newID when the existing entry is missing or expired.
	entry, ok = sessionIDCache[key]
	if !ok || entry.value == "" || !entry.expire.After(now) {
		entry.value = newID
	}
	entry.expire = now.Add(sessionIDTTL)
	sessionIDCache[key] = entry
	sessionIDCacheMu.Unlock()
	return entry.value
}

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"

View File

@@ -1,9 +1,7 @@
package executor package helps
import ( import (
"fmt" "fmt"
"regexp"
"strconv"
"strings" "strings"
"sync" "sync"
@@ -11,100 +9,80 @@ import (
"github.com/tiktoken-go/tokenizer" "github.com/tiktoken-go/tokenizer"
) )
// tokenizerCache stores tokenizer instances to avoid repeated creation // tokenizerCache stores tokenizer instances to avoid repeated creation.
var tokenizerCache sync.Map var tokenizerCache sync.Map
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models type adjustedTokenizer struct {
// where tiktoken may not accurately estimate token counts (e.g., Claude models) tokenizer.Codec
type TokenizerWrapper struct { adjustmentFactor float64
Codec tokenizer.Codec
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
} }
// Count returns the token count with adjustment factor applied func (tw *adjustedTokenizer) Count(text string) (int, error) {
func (tw *TokenizerWrapper) Count(text string) (int, error) {
count, err := tw.Codec.Count(text) count, err := tw.Codec.Count(text)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 { if tw.adjustmentFactor > 0 && tw.adjustmentFactor != 1.0 {
return int(float64(count) * tw.AdjustmentFactor), nil return int(float64(count) * tw.adjustmentFactor), nil
} }
return count, nil return count, nil
} }
// getTokenizer returns a cached tokenizer for the given model. // TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
// This improves performance by avoiding repeated tokenizer creation. // For Claude-like models, it applies an adjustment factor since tiktoken may underestimate token counts.
func getTokenizer(model string) (*TokenizerWrapper, error) { func TokenizerForModel(model string) (tokenizer.Codec, error) {
// Check cache first sanitized := strings.ToLower(strings.TrimSpace(model))
if cached, ok := tokenizerCache.Load(model); ok { if cached, ok := tokenizerCache.Load(sanitized); ok {
return cached.(*TokenizerWrapper), nil return cached.(tokenizer.Codec), nil
} }
// Cache miss, create new tokenizer enc, err := tokenizerForModel(sanitized)
wrapper, err := tokenizerForModel(model)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Store in cache (use LoadOrStore to handle race conditions) actual, _ := tokenizerCache.LoadOrStore(sanitized, enc)
actual, _ := tokenizerCache.LoadOrStore(model, wrapper) return actual.(tokenizer.Codec), nil
return actual.(*TokenizerWrapper), nil
} }
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id. func tokenizerForModel(sanitized string) (tokenizer.Codec, error) {
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate. if sanitized == "" {
func tokenizerForModel(model string) (*TokenizerWrapper, error) { return tokenizer.Get(tokenizer.Cl100kBase)
sanitized := strings.ToLower(strings.TrimSpace(model)) }
// Claude models use cl100k_base with 1.1 adjustment factor // Claude models use cl100k_base with an adjustment factor because tiktoken may underestimate.
// because tiktoken may underestimate Claude's actual token count
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") { if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
enc, err := tokenizer.Get(tokenizer.Cl100kBase) enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil return &adjustedTokenizer{Codec: enc, adjustmentFactor: 1.1}, nil
} }
var enc tokenizer.Codec
var err error
switch { switch {
case sanitized == "":
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5.2"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5"): case strings.HasPrefix(sanitized, "gpt-5"):
enc, err = tokenizer.ForModel(tokenizer.GPT5) return tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"): case strings.HasPrefix(sanitized, "gpt-4.1"):
enc, err = tokenizer.ForModel(tokenizer.GPT41) return tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"): case strings.HasPrefix(sanitized, "gpt-4o"):
enc, err = tokenizer.ForModel(tokenizer.GPT4o) return tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"): case strings.HasPrefix(sanitized, "gpt-4"):
enc, err = tokenizer.ForModel(tokenizer.GPT4) return tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"): case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo) return tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"): case strings.HasPrefix(sanitized, "o1"):
enc, err = tokenizer.ForModel(tokenizer.O1) return tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"): case strings.HasPrefix(sanitized, "o3"):
enc, err = tokenizer.ForModel(tokenizer.O3) return tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"): case strings.HasPrefix(sanitized, "o4"):
enc, err = tokenizer.ForModel(tokenizer.O4Mini) return tokenizer.ForModel(tokenizer.O4Mini)
default: default:
enc, err = tokenizer.Get(tokenizer.O200kBase) return tokenizer.Get(tokenizer.O200kBase)
} }
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
} }
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads. // CountOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
if enc == nil { if enc == nil {
return 0, fmt.Errorf("encoder is nil") return 0, fmt.Errorf("encoder is nil")
} }
@@ -128,22 +106,15 @@ func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
return 0, nil return 0, nil
} }
// Count text tokens
count, err := enc.Count(joined) count, err := enc.Count(joined)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return int64(count), nil
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
} }
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads. // CountClaudeChatTokens approximates prompt tokens for Claude API chat payloads.
// This handles Claude's message format with system, messages, and tools. func CountClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
// Image tokens are estimated based on image dimensions when available.
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil { if enc == nil {
return 0, fmt.Errorf("encoder is nil") return 0, fmt.Errorf("encoder is nil")
} }
@@ -153,185 +124,25 @@ func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
root := gjson.ParseBytes(payload) root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32) segments := make([]string, 0, 32)
imageTokens := 0
// Collect system prompt (can be string or array of content blocks) collectClaudeContent(root.Get("system"), &segments, &imageTokens)
collectClaudeSystem(root.Get("system"), &segments) collectClaudeMessages(root.Get("messages"), &segments, &imageTokens)
// Collect messages
collectClaudeMessages(root.Get("messages"), &segments)
// Collect tools
collectClaudeTools(root.Get("tools"), &segments) collectClaudeTools(root.Get("tools"), &segments)
joined := strings.TrimSpace(strings.Join(segments, "\n")) joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" { if joined == "" {
return 0, nil return int64(imageTokens), nil
} }
// Count text tokens
count, err := enc.Count(joined) count, err := enc.Count(joined)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return int64(count + imageTokens), nil
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
} }
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens // BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`) func BuildOpenAIUsageJSON(count int64) []byte {
// extractImageTokens extracts image token estimates from placeholder text.
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
func extractImageTokens(text string) int {
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
total := 0
for _, match := range matches {
if len(match) > 1 {
if tokens, err := strconv.Atoi(match[1]); err == nil {
total += tokens
}
}
}
return total
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image)
return 1000
}
tokens := int(width * height / 750)
// Apply bounds
if tokens < 85 {
tokens = 85
}
if tokens > 1590 {
tokens = 1590
}
return tokens
}
// collectClaudeSystem extracts text from Claude's system field.
// System can be a string or an array of content blocks.
func collectClaudeSystem(system gjson.Result, segments *[]string) {
if !system.Exists() {
return
}
if system.Type == gjson.String {
addIfNotEmpty(segments, system.String())
return
}
if system.IsArray() {
system.ForEach(func(_, block gjson.Result) bool {
blockType := block.Get("type").String()
if blockType == "text" || blockType == "" {
addIfNotEmpty(segments, block.Get("text").String())
}
// Also handle plain string blocks
if block.Type == gjson.String {
addIfNotEmpty(segments, block.String())
}
return true
})
}
}
// collectClaudeMessages extracts text from Claude's messages array.
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments)
return true
})
}
// collectClaudeContent extracts text from Claude's content field.
// Content can be a string or an array of content blocks.
// For images, estimates token count based on dimensions when available.
func collectClaudeContent(content gjson.Result, segments *[]string) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
// Estimate image tokens based on dimensions if available
source := part.Get("source")
if source.Exists() {
width := source.Get("width").Float()
height := source.Get("height").Float()
if width > 0 && height > 0 {
tokens := estimateImageTokens(width, height)
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
} else {
// No dimensions available, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
} else {
// No source info, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
// For unknown types, try to extract any text content
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
}
}
// collectClaudeTools extracts text from Claude's tools array.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
func buildOpenAIUsageJSON(count int64) []byte {
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count)) return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
} }
@@ -390,6 +201,10 @@ func collectOpenAIContent(content gjson.Result, segments *[]string) {
} }
} }
func CollectOpenAIContent(content gjson.Result, segments *[]string) {
collectOpenAIContent(content, segments)
}
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) { func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
if !calls.Exists() || !calls.IsArray() { if !calls.Exists() || !calls.IsArray() {
return return
@@ -487,6 +302,98 @@ func appendToolPayload(tool gjson.Result, segments *[]string) {
} }
} }
// collectClaudeMessages walks a Claude messages array, appending each message's
// role and textual content to segments and accumulating estimated image tokens.
func collectClaudeMessages(messages gjson.Result, segments *[]string, imageTokens *int) {
	if !messages.Exists() || !messages.IsArray() {
		return
	}
	messages.ForEach(func(_, msg gjson.Result) bool {
		addIfNotEmpty(segments, msg.Get("role").String())
		collectClaudeContent(msg.Get("content"), segments, imageTokens)
		return true
	})
}
// collectClaudeContent extracts text from a Claude content field, which may be
// a plain string or an array of typed content blocks. Image blocks contribute
// an estimate to imageTokens instead of adding text to segments.
func collectClaudeContent(content gjson.Result, segments *[]string, imageTokens *int) {
	if !content.Exists() {
		return
	}
	if content.Type == gjson.String {
		addIfNotEmpty(segments, content.String())
		return
	}
	if !content.IsArray() {
		// Non-array JSON content: record its raw form for token counting.
		if content.Type == gjson.JSON {
			addIfNotEmpty(segments, content.Raw)
		}
		return
	}
	content.ForEach(func(_, block gjson.Result) bool {
		switch block.Get("type").String() {
		case "text":
			addIfNotEmpty(segments, block.Get("text").String())
		case "image":
			src := block.Get("source")
			if imageTokens != nil {
				*imageTokens += estimateImageTokens(src.Get("width").Float(), src.Get("height").Float())
			}
		case "tool_use":
			addIfNotEmpty(segments, block.Get("id").String())
			addIfNotEmpty(segments, block.Get("name").String())
			if input := block.Get("input"); input.Exists() {
				addIfNotEmpty(segments, input.Raw)
			}
		case "tool_result":
			addIfNotEmpty(segments, block.Get("tool_use_id").String())
			collectClaudeContent(block.Get("content"), segments, imageTokens)
		case "thinking":
			addIfNotEmpty(segments, block.Get("thinking").String())
		default:
			// Unknown block types: salvage any string or raw JSON payload.
			if block.Type == gjson.String {
				addIfNotEmpty(segments, block.String())
			} else if block.Type == gjson.JSON {
				addIfNotEmpty(segments, block.Raw)
			}
		}
		return true
	})
}
// collectClaudeTools appends each tool's name, description, and raw input
// schema to segments so they are included in the token estimate.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
	if !tools.Exists() || !tools.IsArray() {
		return
	}
	tools.ForEach(func(_, tool gjson.Result) bool {
		addIfNotEmpty(segments, tool.Get("name").String())
		addIfNotEmpty(segments, tool.Get("description").String())
		if schema := tool.Get("input_schema"); schema.Exists() {
			addIfNotEmpty(segments, schema.Raw)
		}
		return true
	})
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750,
// clamped to [85, 1590]. Missing or invalid dimensions fall back to a
// medium-sized-image default of 1000 tokens.
func estimateImageTokens(width, height float64) int {
	if width <= 0 || height <= 0 {
		return 1000 // no usable dimensions; assume a medium-sized image
	}
	tokens := int(width * height / 750)
	switch {
	case tokens < 85:
		return 85
	case tokens > 1590:
		return 1590
	default:
		return tokens
	}
}
func addIfNotEmpty(segments *[]string, value string) { func addIfNotEmpty(segments *[]string, value string) {
if segments == nil { if segments == nil {
return return

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"bytes" "bytes"
@@ -15,7 +15,7 @@ import (
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
type usageReporter struct { type UsageReporter struct {
provider string provider string
model string model string
authID string authID string
@@ -26,9 +26,9 @@ type usageReporter struct {
once sync.Once once sync.Once
} }
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter { func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := apiKeyFromContext(ctx) apiKey := APIKeyFromContext(ctx)
reporter := &usageReporter{ reporter := &UsageReporter{
provider: provider, provider: provider,
model: model, model: model,
requestedAt: time.Now(), requestedAt: time.Now(),
@@ -42,24 +42,24 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
return reporter return reporter
} }
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) { func (r *UsageReporter) Publish(ctx context.Context, detail usage.Detail) {
r.publishWithOutcome(ctx, detail, false) r.publishWithOutcome(ctx, detail, false)
} }
func (r *usageReporter) publishFailure(ctx context.Context) { func (r *UsageReporter) PublishFailure(ctx context.Context) {
r.publishWithOutcome(ctx, usage.Detail{}, true) r.publishWithOutcome(ctx, usage.Detail{}, true)
} }
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) { func (r *UsageReporter) TrackFailure(ctx context.Context, errPtr *error) {
if r == nil || errPtr == nil { if r == nil || errPtr == nil {
return return
} }
if *errPtr != nil { if *errPtr != nil {
r.publishFailure(ctx) r.PublishFailure(ctx)
} }
} }
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) { func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
if r == nil { if r == nil {
return return
} }
@@ -69,9 +69,6 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
detail.TotalTokens = total detail.TotalTokens = total
} }
} }
if detail.InputTokens == 0 && detail.OutputTokens == 0 && detail.ReasoningTokens == 0 && detail.CachedTokens == 0 && detail.TotalTokens == 0 && !failed {
return
}
r.once.Do(func() { r.once.Do(func() {
usage.PublishRecord(ctx, r.buildRecord(detail, failed)) usage.PublishRecord(ctx, r.buildRecord(detail, failed))
}) })
@@ -81,7 +78,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
// It is safe to call multiple times; only the first call wins due to once.Do. // It is safe to call multiple times; only the first call wins due to once.Do.
// This is used to ensure request counting even when upstream responses do not // This is used to ensure request counting even when upstream responses do not
// include any usage fields (tokens), especially for streaming paths. // include any usage fields (tokens), especially for streaming paths.
func (r *usageReporter) ensurePublished(ctx context.Context) { func (r *UsageReporter) EnsurePublished(ctx context.Context) {
if r == nil { if r == nil {
return return
} }
@@ -90,7 +87,7 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
}) })
} }
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record { func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
if r == nil { if r == nil {
return usage.Record{Detail: detail, Failed: failed} return usage.Record{Detail: detail, Failed: failed}
} }
@@ -108,7 +105,7 @@ func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
} }
} }
func (r *usageReporter) latency() time.Duration { func (r *UsageReporter) latency() time.Duration {
if r == nil || r.requestedAt.IsZero() { if r == nil || r.requestedAt.IsZero() {
return 0 return 0
} }
@@ -119,7 +116,7 @@ func (r *usageReporter) latency() time.Duration {
return latency return latency
} }
func apiKeyFromContext(ctx context.Context) string { func APIKeyFromContext(ctx context.Context) string {
if ctx == nil { if ctx == nil {
return "" return ""
} }
@@ -184,7 +181,7 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
return "" return ""
} }
func parseCodexUsage(data []byte) (usage.Detail, bool) { func ParseCodexUsage(data []byte) (usage.Detail, bool) {
usageNode := gjson.ParseBytes(data).Get("response.usage") usageNode := gjson.ParseBytes(data).Get("response.usage")
if !usageNode.Exists() { if !usageNode.Exists() {
return usage.Detail{}, false return usage.Detail{}, false
@@ -203,7 +200,7 @@ func parseCodexUsage(data []byte) (usage.Detail, bool) {
return detail, true return detail, true
} }
func parseOpenAIUsage(data []byte) usage.Detail { func ParseOpenAIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage") usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() { if !usageNode.Exists() {
return usage.Detail{} return usage.Detail{}
@@ -238,7 +235,7 @@ func parseOpenAIUsage(data []byte) usage.Detail {
return detail return detail
} }
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) { func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line) payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) { if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false return usage.Detail{}, false
@@ -247,59 +244,40 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
if !usageNode.Exists() { if !usageNode.Exists() {
return usage.Detail{}, false return usage.Detail{}, false
} }
inputNode := usageNode.Get("prompt_tokens")
if !inputNode.Exists() {
inputNode = usageNode.Get("input_tokens")
}
outputNode := usageNode.Get("completion_tokens")
if !outputNode.Exists() {
outputNode = usageNode.Get("output_tokens")
}
detail := usage.Detail{ detail := usage.Detail{
InputTokens: usageNode.Get("prompt_tokens").Int(), InputTokens: inputNode.Int(),
OutputTokens: usageNode.Get("completion_tokens").Int(), OutputTokens: outputNode.Int(),
TotalTokens: usageNode.Get("total_tokens").Int(), TotalTokens: usageNode.Get("total_tokens").Int(),
} }
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
if !cached.Exists() {
cached = usageNode.Get("input_tokens_details.cached_tokens")
}
if cached.Exists() {
detail.CachedTokens = cached.Int() detail.CachedTokens = cached.Int()
} }
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
if !reasoning.Exists() {
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
}
if reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int() detail.ReasoningTokens = reasoning.Int()
} }
return detail, true return detail, true
} }
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail { func ParseClaudeUsage(data []byte) usage.Detail {
detail := usage.Detail{
InputTokens: usageNode.Get("input_tokens").Int(),
OutputTokens: usageNode.Get("output_tokens").Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
}
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
}
return parseOpenAIResponsesUsageDetail(usageNode)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
usageNode := gjson.GetBytes(payload, "usage")
if !usageNode.Exists() {
return usage.Detail{}, false
}
return parseOpenAIResponsesUsageDetail(usageNode), true
}
func parseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage") usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() { if !usageNode.Exists() {
return usage.Detail{} return usage.Detail{}
@@ -317,7 +295,7 @@ func parseClaudeUsage(data []byte) usage.Detail {
return detail return detail
} }
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) { func ParseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line) payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) { if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false return usage.Detail{}, false
@@ -352,7 +330,7 @@ func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
return detail return detail
} }
func parseGeminiCLIUsage(data []byte) usage.Detail { func ParseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data) usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata") node := usageNode.Get("response.usageMetadata")
if !node.Exists() { if !node.Exists() {
@@ -364,7 +342,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node) return parseGeminiFamilyUsageDetail(node)
} }
func parseGeminiUsage(data []byte) usage.Detail { func ParseGeminiUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data) usageNode := gjson.ParseBytes(data)
node := usageNode.Get("usageMetadata") node := usageNode.Get("usageMetadata")
if !node.Exists() { if !node.Exists() {
@@ -376,7 +354,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node) return parseGeminiFamilyUsageDetail(node)
} }
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) { func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line) payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) { if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false return usage.Detail{}, false
@@ -391,7 +369,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
return parseGeminiFamilyUsageDetail(node), true return parseGeminiFamilyUsageDetail(node), true
} }
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) { func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line) payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) { if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false return usage.Detail{}, false
@@ -406,7 +384,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
return parseGeminiFamilyUsageDetail(node), true return parseGeminiFamilyUsageDetail(node), true
} }
func parseAntigravityUsage(data []byte) usage.Detail { func ParseAntigravityUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data) usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata") node := usageNode.Get("response.usageMetadata")
if !node.Exists() { if !node.Exists() {
@@ -421,7 +399,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node) return parseGeminiFamilyUsageDetail(node)
} }
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) { func ParseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line) payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) { if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false return usage.Detail{}, false
@@ -590,6 +568,10 @@ func isStopChunkWithoutUsage(jsonBytes []byte) bool {
return !hasUsageMetadata(jsonBytes) return !hasUsageMetadata(jsonBytes)
} }
func JSONPayload(line []byte) []byte {
return jsonPayload(line)
}
func jsonPayload(line []byte) []byte { func jsonPayload(line []byte) []byte {
trimmed := bytes.TrimSpace(line) trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 { if len(trimmed) == 0 {

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"testing" "testing"
@@ -9,7 +9,7 @@ import (
func TestParseOpenAIUsageChatCompletions(t *testing.T) { func TestParseOpenAIUsageChatCompletions(t *testing.T) {
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`) data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
detail := parseOpenAIUsage(data) detail := ParseOpenAIUsage(data)
if detail.InputTokens != 1 { if detail.InputTokens != 1 {
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1) t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
} }
@@ -29,7 +29,7 @@ func TestParseOpenAIUsageChatCompletions(t *testing.T) {
func TestParseOpenAIUsageResponses(t *testing.T) { func TestParseOpenAIUsageResponses(t *testing.T) {
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`) data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
detail := parseOpenAIUsage(data) detail := ParseOpenAIUsage(data)
if detail.InputTokens != 10 { if detail.InputTokens != 10 {
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10) t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
} }
@@ -48,7 +48,7 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
} }
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) { func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
reporter := &usageReporter{ reporter := &UsageReporter{
provider: "openai", provider: "openai",
model: "gpt-5.4", model: "gpt-5.4",
requestedAt: time.Now().Add(-1500 * time.Millisecond), requestedAt: time.Now().Add(-1500 * time.Millisecond),

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"crypto/sha256" "crypto/sha256"
@@ -49,7 +49,7 @@ func userIDCacheKey(apiKey string) string {
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
} }
func cachedUserID(apiKey string) string { func CachedUserID(apiKey string) string {
if apiKey == "" { if apiKey == "" {
return generateFakeUserID() return generateFakeUserID()
} }

View File

@@ -1,4 +1,4 @@
package executor package helps
import ( import (
"testing" "testing"
@@ -14,8 +14,8 @@ func resetUserIDCache() {
func TestCachedUserID_ReusesWithinTTL(t *testing.T) { func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
resetUserIDCache() resetUserIDCache()
first := cachedUserID("api-key-1") first := CachedUserID("api-key-1")
second := cachedUserID("api-key-1") second := CachedUserID("api-key-1")
if first == "" { if first == "" {
t.Fatal("expected generated user_id to be non-empty") t.Fatal("expected generated user_id to be non-empty")
@@ -28,7 +28,7 @@ func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) { func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
resetUserIDCache() resetUserIDCache()
expiredID := cachedUserID("api-key-expired") expiredID := CachedUserID("api-key-expired")
cacheKey := userIDCacheKey("api-key-expired") cacheKey := userIDCacheKey("api-key-expired")
userIDCacheMu.Lock() userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{ userIDCache[cacheKey] = userIDCacheEntry{
@@ -37,7 +37,7 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
} }
userIDCacheMu.Unlock() userIDCacheMu.Unlock()
newID := cachedUserID("api-key-expired") newID := CachedUserID("api-key-expired")
if newID == expiredID { if newID == expiredID {
t.Fatalf("expected expired user_id to be replaced, got %q", newID) t.Fatalf("expected expired user_id to be replaced, got %q", newID)
} }
@@ -49,8 +49,8 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) { func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
resetUserIDCache() resetUserIDCache()
first := cachedUserID("api-key-1") first := CachedUserID("api-key-1")
second := cachedUserID("api-key-2") second := CachedUserID("api-key-2")
if first == second { if first == second {
t.Fatalf("expected different API keys to have different user_ids, got %q", first) t.Fatalf("expected different API keys to have different user_ids, got %q", first)
@@ -61,7 +61,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
resetUserIDCache() resetUserIDCache()
key := "api-key-renew" key := "api-key-renew"
id := cachedUserID(key) id := CachedUserID(key)
cacheKey := userIDCacheKey(key) cacheKey := userIDCacheKey(key)
soon := time.Now() soon := time.Now()
@@ -72,7 +72,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
} }
userIDCacheMu.Unlock() userIDCacheMu.Unlock()
if refreshed := cachedUserID(key); refreshed != id { if refreshed := CachedUserID(key); refreshed != id {
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed) t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
} }

View File

@@ -0,0 +1,188 @@
package helps
import (
"net"
"net/http"
"strings"
"sync"
"time"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
)
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
type utlsRoundTripper struct {
	// mu guards connections and pending.
	mu sync.Mutex
	// connections caches at most one HTTP/2 client connection per hostname.
	connections map[string]*http2.ClientConn
	// pending holds a per-host condition variable while a dial for that host
	// is in flight, so concurrent callers wait instead of dialing again.
	pending map[string]*sync.Cond
	// dialer establishes the underlying TCP connection (direct or via the
	// configured proxy).
	dialer proxy.Dialer
}
// newUtlsRoundTripper builds a round tripper whose connections are dialed
// through the configured proxy when one is usable, and directly otherwise.
// A proxy configuration error is logged and the direct dialer is kept.
func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
	dialer := proxy.Dialer(proxy.Direct)
	if proxyURL != "" {
		switch proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL); {
		case errBuild != nil:
			log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
		case mode != proxyutil.ModeInherit && proxyDialer != nil:
			dialer = proxyDialer
		}
	}
	return &utlsRoundTripper{
		connections: map[string]*http2.ClientConn{},
		pending:     map[string]*sync.Cond{},
		dialer:      dialer,
	}
}
// getOrCreateConnection returns a cached HTTP/2 connection for host that can
// still take requests, or dials a new one. Concurrent callers for the same
// host are serialized on a per-host sync.Cond so only one dial is in flight
// at a time; waiters are woken by Broadcast when the dial completes.
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
	t.mu.Lock()
	for {
		// Reuse a cached connection if it is still usable.
		if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
			t.mu.Unlock()
			return h2Conn, nil
		}
		// If another goroutine is already dialing this host, wait and
		// re-check. Waiting in a loop (per the sync.Cond contract) guards
		// against spurious wakeups and against a failed dial: after a
		// failure, exactly one waiter observes the cleared pending entry,
		// breaks out, and becomes the new dialer, instead of every waiter
		// dialing at once.
		cond, ok := t.pending[host]
		if !ok {
			break
		}
		cond.Wait()
	}
	// Become the dialer for this host.
	cond := sync.NewCond(&t.mu)
	t.pending[host] = cond
	t.mu.Unlock()

	h2Conn, err := t.createConnection(host, addr)

	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.pending, host)
	cond.Broadcast()
	if err != nil {
		return nil, err
	}
	t.connections[host] = h2Conn
	return h2Conn, nil
}
// createConnection dials addr over TCP, performs a TLS handshake presenting a
// Chrome fingerprint (utls HelloChrome_Auto, SNI = host), and wraps the
// resulting connection in an HTTP/2 client connection.
//
// The handshake is bounded by a 30s deadline (matching the dial timeout used
// by the standard transport in this file) so a stalled peer cannot hang the
// caller indefinitely; the deadline is cleared before the connection is handed
// to http2 so it does not fire on later reads/writes.
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
	conn, err := t.dialer.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}
	// Best effort: some proxy dialers may return conns without deadline
	// support, in which case the handshake simply runs unbounded as before.
	_ = conn.SetDeadline(time.Now().Add(30 * time.Second))
	tlsConfig := &tls.Config{ServerName: host}
	tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
	if err := tlsConn.Handshake(); err != nil {
		conn.Close()
		return nil, err
	}
	// Clear the handshake deadline for the connection's remaining lifetime.
	_ = conn.SetDeadline(time.Time{})
	tr := &http2.Transport{}
	h2Conn, err := tr.NewClientConn(tlsConn)
	if err != nil {
		tlsConn.Close()
		return nil, err
	}
	return h2Conn, nil
}
// RoundTrip sends req over a pooled HTTP/2 connection established with the
// Chrome TLS fingerprint. On a transport error the cached connection for the
// host is evicted (if it is still the cached one) so the next request dials
// fresh. NOTE(review): the evicted conn is not explicitly closed — closing it
// here could abort sibling in-flight streams; confirm whether idle cleanup is
// handled elsewhere.
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	host := req.URL.Hostname()
	port := req.URL.Port()
	if port == "" {
		port = "443" // this transport is only used for HTTPS requests
	}
	h2Conn, err := t.getOrCreateConnection(host, net.JoinHostPort(host, port))
	if err != nil {
		return nil, err
	}
	resp, errRT := h2Conn.RoundTrip(req)
	if errRT == nil {
		return resp, nil
	}
	// Evict only if the pool still holds this exact connection; a newer one
	// may already have replaced it.
	t.mu.Lock()
	if cached, ok := t.connections[host]; ok && cached == h2Conn {
		delete(t.connections, host)
	}
	t.mu.Unlock()
	return nil, errRT
}
// anthropicHosts contains the hosts that should use utls Chrome TLS fingerprint.
// Lookups are performed with the lowercased hostname (see
// fallbackRoundTripper.RoundTrip).
var anthropicHosts = map[string]struct{}{
	"api.anthropic.com": {},
}
// fallbackRoundTripper uses utls for Anthropic HTTPS hosts and falls back to
// standard transport for all other requests (non-HTTPS or non-Anthropic hosts).
type fallbackRoundTripper struct {
	// utls serves HTTPS requests whose host appears in anthropicHosts.
	utls *utlsRoundTripper
	// fallback serves every other request.
	fallback http.RoundTripper
}
// RoundTrip routes HTTPS requests bound for a known Anthropic host through the
// utls transport (Chrome TLS fingerprint); everything else goes through the
// standard fallback transport.
func (f *fallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	_, isAnthropic := anthropicHosts[strings.ToLower(req.URL.Hostname())]
	if isAnthropic && req.URL.Scheme == "https" {
		return f.utls.RoundTrip(req)
	}
	return f.fallback.RoundTrip(req)
}
// NewUtlsHTTPClient creates an HTTP client using utls Chrome TLS fingerprint.
// Use this for Claude API requests to match real Claude Code's TLS behavior.
// Falls back to standard transport for non-HTTPS requests.
//
// The proxy URL is taken from the auth entry when set, otherwise from the
// global config; timeout, when positive, becomes the client's overall timeout.
func NewUtlsHTTPClient(cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
	proxyURL := ""
	if auth != nil {
		proxyURL = strings.TrimSpace(auth.ProxyURL)
	}
	if proxyURL == "" && cfg != nil {
		proxyURL = strings.TrimSpace(cfg.ProxyURL)
	}

	// Non-utls traffic uses a plain transport, proxy-aware when configured.
	standard := http.RoundTripper(&http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
	})
	if proxyURL != "" {
		if proxied := buildProxyTransport(proxyURL); proxied != nil {
			standard = proxied
		}
	}

	client := &http.Client{
		Transport: &fallbackRoundTripper{
			utls:     newUtlsRoundTripper(proxyURL),
			fallback: standard,
		},
	}
	if timeout > 0 {
		client.Timeout = timeout
	}
	return client
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -66,7 +67,7 @@ func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
@@ -86,8 +87,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = iflowauth.DefaultAPIBaseURL baseURL = iflowauth.DefaultAPIBaseURL
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
@@ -106,8 +107,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
} }
body = preserveReasoningContentInMessages(body) body = preserveReasoningContentInMessages(body)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -116,13 +117,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, err return resp, err
} }
applyIFlowHeaders(httpReq, apiKey, false) applyIFlowHeaders(httpReq, apiKey, false)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint, URL: endpoint,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -134,10 +140,10 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -145,25 +151,25 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
log.Errorf("iflow executor: close response body error: %v", errClose) log.Errorf("iflow executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
data, err := io.ReadAll(httpResp.Body) data, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data)) reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
// Ensure usage is recorded even if upstream omits usage metadata. // Ensure usage is recorded even if upstream omits usage metadata.
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
var param any var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve // Note: TranslateNonStream uses req.Model (original with suffix) to preserve
@@ -189,8 +195,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = iflowauth.DefaultAPIBaseURL baseURL = iflowauth.DefaultAPIBaseURL
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
@@ -214,8 +220,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 { if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body) body = ensureToolsArray(body)
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -224,13 +230,18 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err return nil, err
} }
applyIFlowHeaders(httpReq, apiKey, true) applyIFlowHeaders(httpReq, apiKey, true)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint, URL: endpoint,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -242,21 +253,21 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, _ := io.ReadAll(httpResp.Body) data, _ := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("iflow executor: close response body error: %v", errClose) log.Errorf("iflow executor: close response body error: %v", errClose)
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)} err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err return nil, err
} }
@@ -275,9 +286,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok { if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
@@ -285,12 +296,12 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
// Guarantee a usage record exists even if the stream never emitted usage data. // Guarantee a usage record exists even if the stream never emitted usage data.
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
}() }()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -303,17 +314,17 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
enc, err := tokenizerForModel(baseModel) enc, err := helps.TokenizerForModel(baseModel)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
} }
count, err := countOpenAIChatTokens(enc, body) count, err := helps.CountOpenAIChatTokens(enc, body)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
} }
usageJSON := buildOpenAIUsageJSON(count) usageJSON := helps.BuildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translated}, nil return cliproxyexecutor.Response{Payload: translated}, nil
} }

View File

@@ -15,7 +15,9 @@ import (
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi" kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -45,6 +47,11 @@ func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
if strings.TrimSpace(token) != "" { if strings.TrimSpace(token) != "" {
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
} }
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil return nil
} }
@@ -60,7 +67,7 @@ func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
@@ -76,8 +83,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
token := kimiCreds(auth) token := kimiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload originalPayloadSource := req.Payload
@@ -100,8 +107,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = normalizeKimiToolMessageLinks(body) body, err = normalizeKimiToolMessageLinks(body)
if err != nil { if err != nil {
return resp, err return resp, err
@@ -113,13 +120,18 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err return resp, err
} }
applyKimiHeadersWithAuth(httpReq, token, false, auth) applyKimiHeadersWithAuth(httpReq, token, false, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -131,10 +143,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -142,21 +154,21 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Errorf("kimi executor: close response body error: %v", errClose) log.Errorf("kimi executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
data, err := io.ReadAll(httpResp.Body) data, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data)) reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
var param any var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve // Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility. // the original model name in the response for client compatibility.
@@ -176,8 +188,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
token := kimiCreds(auth) token := kimiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload originalPayloadSource := req.Payload
@@ -204,8 +216,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if err != nil { if err != nil {
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err) return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = normalizeKimiToolMessageLinks(body) body, err = normalizeKimiToolMessageLinks(body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -217,13 +229,18 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err return nil, err
} }
applyKimiHeadersWithAuth(httpReq, token, true, auth) applyKimiHeadersWithAuth(httpReq, token, true, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -235,17 +252,17 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("kimi executor: close response body error: %v", errClose) log.Errorf("kimi executor: close response body error: %v", errClose)
} }
@@ -265,9 +282,9 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok { if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
@@ -279,8 +296,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]} out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
}() }()

View File

@@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -65,15 +66,15 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth) baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" { if baseURL == "" {
@@ -95,8 +96,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
originalPayload := originalPayloadSource originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream) originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream) translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
if opts.Alt == "responses/compact" { if opts.Alt == "responses/compact" {
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil { if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
translated = updated translated = updated
@@ -129,7 +130,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -141,10 +142,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -152,23 +153,23 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
log.Errorf("openai compat executor: close response body error: %v", errClose) log.Errorf("openai compat executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
body, err := io.ReadAll(httpResp.Body) body, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, body) helps.AppendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body)) reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
// Ensure we at least record the request even if upstream doesn't return usage // Ensure we at least record the request even if upstream doesn't return usage
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
// Translate response back to source format when needed // Translate response back to source format when needed
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
@@ -179,8 +180,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth) baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" { if baseURL == "" {
@@ -197,8 +198,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
originalPayload := originalPayloadSource originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil { if err != nil {
@@ -232,7 +233,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -244,17 +245,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose) log.Errorf("openai compat executor: close response body error: %v", errClose)
} }
@@ -274,9 +275,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok { if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
if len(line) == 0 { if len(line) == 0 {
continue continue
@@ -294,12 +295,20 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
// In case the upstream close the stream without a terminal [DONE] marker.
// Feed a synthetic done marker through the translator so pending
// response.completed events are still emitted exactly once.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
} }
// Ensure we record the request if no usage chunk was ever seen // Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx) reporter.EnsurePublished(ctx)
}() }()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
} }
@@ -318,17 +327,17 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
} }
enc, err := tokenizerForModel(modelForCounting) enc, err := helps.TokenizerForModel(modelForCounting)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
} }
count, err := countOpenAIChatTokens(enc, translated) count, err := helps.CountOpenAIChatTokens(enc, translated)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
} }
usageJSON := buildOpenAIUsageJSON(count) usageJSON := helps.BuildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translatedUsage}, nil return cliproxyexecutor.Response{Payload: translatedUsage}, nil
} }

View File

@@ -13,7 +13,9 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -23,20 +25,12 @@ import (
) )
const ( const (
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)" qwenUserAgent = "QwenCode/0.14.2 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration qwenRateLimitWindow = time.Minute // sliding window duration
) )
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls. var qwenDefaultSystemMessage = []byte(`{"role":"system","content":[{"type":"text","text":"","cache_control":{"type":"ephemeral"}}]}`)
var qwenBeijingLoc = func() *time.Location {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil || loc == nil {
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
return time.FixedZone("CST", 8*3600)
}
return loc
}()
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion. // qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
var qwenQuotaCodes = map[string]struct{}{ var qwenQuotaCodes = map[string]struct{}{
@@ -152,20 +146,110 @@ func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int,
// Qwen returns 403 for quota errors, 429 for rate limits // Qwen returns 403 for quota errors, 429 for rate limits
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) { if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay() // Do not force an excessively long retry-after (e.g. until tomorrow), otherwise
retryAfter = &cooldown // the global request-retry scheduler may skip retries due to max-retry-interval.
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown) helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d)", httpCode, errCode)
} }
return errCode, retryAfter return errCode, retryAfter
} }
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8). // ensureQwenSystemMessage ensures the request has a single system message at the beginning.
// Qwen's daily quota resets at 00:00 Beijing time. // It always injects the default system prompt and merges any user-provided system messages
func timeUntilNextDay() time.Duration { // into the injected system message content to satisfy Qwen's strict message ordering rules.
now := time.Now() func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
nowLocal := now.In(qwenBeijingLoc) isInjectedSystemPart := func(part gjson.Result) bool {
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc) if !part.Exists() || !part.IsObject() {
return tomorrow.Sub(now) return false
}
if !strings.EqualFold(part.Get("type").String(), "text") {
return false
}
if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") {
return false
}
text := part.Get("text").String()
return text == "" || text == "You are Qwen Code."
}
defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content")
var systemParts []any
if defaultParts.Exists() && defaultParts.IsArray() {
for _, part := range defaultParts.Array() {
systemParts = append(systemParts, part.Value())
}
}
if len(systemParts) == 0 {
systemParts = append(systemParts, map[string]any{
"type": "text",
"text": "You are Qwen Code.",
"cache_control": map[string]any{
"type": "ephemeral",
},
})
}
appendSystemContent := func(content gjson.Result) {
makeTextPart := func(text string) map[string]any {
return map[string]any{
"type": "text",
"text": text,
}
}
if !content.Exists() || content.Type == gjson.Null {
return
}
if content.IsArray() {
for _, part := range content.Array() {
if part.Type == gjson.String {
systemParts = append(systemParts, makeTextPart(part.String()))
continue
}
if isInjectedSystemPart(part) {
continue
}
systemParts = append(systemParts, part.Value())
}
return
}
if content.Type == gjson.String {
systemParts = append(systemParts, makeTextPart(content.String()))
return
}
if content.IsObject() {
if isInjectedSystemPart(content) {
return
}
systemParts = append(systemParts, content.Value())
return
}
systemParts = append(systemParts, makeTextPart(content.String()))
}
messages := gjson.GetBytes(payload, "messages")
var nonSystemMessages []any
if messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if strings.EqualFold(msg.Get("role").String(), "system") {
appendSystemContent(msg.Get("content"))
continue
}
nonSystemMessages = append(nonSystemMessages, msg.Value())
}
}
newMessages := make([]any, 0, 1+len(nonSystemMessages))
newMessages = append(newMessages, map[string]any{
"role": "system",
"content": systemParts,
})
newMessages = append(newMessages, nonSystemMessages...)
updated, errSet := sjson.SetBytes(payload, "messages", newMessages)
if errSet != nil {
return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet)
}
return updated, nil
} }
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions. // QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
@@ -202,7 +286,7 @@ func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if err := e.PrepareRequest(httpReq, auth); err != nil { if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err return nil, err
} }
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq) return httpClient.Do(httpReq)
} }
@@ -217,7 +301,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
authID = auth.ID authID = auth.ID
} }
if err := checkQwenRateLimit(authID); err != nil { if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID)) helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err return resp, err
} }
@@ -228,8 +312,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
baseURL = "https://portal.qwen.ai/v1" baseURL = "https://portal.qwen.ai/v1"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
@@ -247,8 +331,12 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err return resp, err
} }
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = ensureQwenSystemMessage(body)
if err != nil {
return resp, err
}
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -256,12 +344,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err return resp, err
} }
applyQwenHeaders(httpReq, token, false) applyQwenHeaders(httpReq, token, false)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authLabel, authType, authValue string var authLabel, authType, authValue string
if auth != nil { if auth != nil {
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -273,10 +366,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
defer func() { defer func() {
@@ -284,23 +377,23 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Errorf("qwen executor: close response body error: %v", errClose) log.Errorf("qwen executor: close response body error: %v", errClose)
} }
}() }()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b) errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter} err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err return resp, err
} }
data, err := io.ReadAll(httpResp.Body) data, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, data) helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data)) reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
var param any var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve // Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility. // the original model name in the response for client compatibility.
@@ -320,7 +413,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
authID = auth.ID authID = auth.ID
} }
if err := checkQwenRateLimit(authID); err != nil { if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID)) helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err return nil, err
} }
@@ -331,8 +424,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
baseURL = "https://portal.qwen.ai/v1" baseURL = "https://portal.qwen.ai/v1"
} }
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
@@ -350,15 +443,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err return nil, err
} }
toolsResult := gjson.GetBytes(body, "tools") // toolsResult := gjson.GetBytes(body, "tools")
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response. // I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
// This will have no real consequences. It's just to scare Qwen3. // This will have no real consequences. It's just to scare Qwen3.
if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() { // if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() {
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`)) // body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
} // }
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
requestedModel := payloadRequestedModel(opts, req.Model) requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = ensureQwenSystemMessage(body)
if err != nil {
return nil, err
}
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -366,12 +463,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err return nil, err
} }
applyQwenHeaders(httpReq, token, true) applyQwenHeaders(httpReq, token, true)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authLabel, authType, authValue string var authLabel, authType, authValue string
if auth != nil { if auth != nil {
authLabel = auth.Label authLabel = auth.Label
authType, authValue = auth.AccountInfo() authType, authValue = auth.AccountInfo()
} }
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url, URL: url,
Method: http.MethodPost, Method: http.MethodPost,
Headers: httpReq.Header.Clone(), Headers: httpReq.Header.Clone(),
@@ -383,19 +485,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue, AuthValue: authValue,
}) })
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq) httpResp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err return nil, err
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) helps.AppendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b) errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose) log.Errorf("qwen executor: close response body error: %v", errClose)
} }
@@ -415,9 +517,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
var param any var param any
for scanner.Scan() { for scanner.Scan() {
line := scanner.Bytes() line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line) helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok { if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail) reporter.Publish(ctx, detail)
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
@@ -429,8 +531,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]} out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx) reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan} out <- cliproxyexecutor.StreamChunk{Err: errScan}
} }
}() }()
@@ -449,17 +551,17 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
modelName = baseModel modelName = baseModel
} }
enc, err := tokenizerForModel(modelName) enc, err := helps.TokenizerForModel(modelName)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
} }
count, err := countOpenAIChatTokens(enc, body) count, err := helps.CountOpenAIChatTokens(enc, body)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
} }
usageJSON := buildOpenAIUsageJSON(count) usageJSON := helps.BuildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translated}, nil return cliproxyexecutor.Response{Payload: translated}, nil
} }
@@ -505,20 +607,23 @@ func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
} }
func applyQwenHeaders(r *http.Request, token string, stream bool) { func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0") r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
r.Header.Set("Sec-Fetch-Mode", "cors") r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Stainless-Lang", "js") r.Header.Set("X-Stainless-Lang", "js")
r.Header.Set("X-Stainless-Arch", "arm64") r.Header.Set("Accept-Language", "*")
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("X-Dashscope-Cachecontrol", "enable") r.Header.Set("X-Dashscope-Cachecontrol", "enable")
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("X-Stainless-Os", "MacOS") r.Header.Set("X-Stainless-Os", "MacOS")
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth") r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
r.Header.Set("X-Stainless-Arch", "arm64")
r.Header.Set("X-Stainless-Runtime", "node") r.Header.Set("X-Stainless-Runtime", "node")
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("Accept-Encoding", "gzip, deflate")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("Sec-Fetch-Mode", "cors")
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Connection", "keep-alive")
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
if stream { if stream {
r.Header.Set("Accept", "text/event-stream") r.Header.Set("Accept", "text/event-stream")
@@ -527,6 +632,26 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
} }
func normaliseQwenBaseURL(resourceURL string) string {
raw := strings.TrimSpace(resourceURL)
if raw == "" {
return ""
}
normalized := raw
lower := strings.ToLower(normalized)
if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
normalized = "https://" + normalized
}
normalized = strings.TrimRight(normalized, "/")
if !strings.HasSuffix(strings.ToLower(normalized), "/v1") {
normalized += "/v1"
}
return normalized
}
func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) { func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
if a == nil { if a == nil {
return "", "" return "", ""
@@ -544,7 +669,7 @@ func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
token = v token = v
} }
if v, ok := a.Metadata["resource_url"].(string); ok { if v, ok := a.Metadata["resource_url"].(string); ok {
baseURL = fmt.Sprintf("https://%s/v1", v) baseURL = normaliseQwenBaseURL(v)
} }
} }
return return

View File

@@ -1,9 +1,13 @@
package executor package executor
import ( import (
"context"
"net/http"
"testing" "testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/tidwall/gjson"
) )
func TestQwenExecutorParseSuffix(t *testing.T) { func TestQwenExecutorParseSuffix(t *testing.T) {
@@ -28,3 +32,180 @@ func TestQwenExecutorParseSuffix(t *testing.T) {
}) })
} }
} }
func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) {
payload := []byte(`{
"model": "qwen3.6-plus",
"stream": true,
"messages": [
{ "role": "system", "content": "ABCDEFG" },
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
if msgs[0].Get("role").String() != "system" {
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
}
parts := msgs[0].Get("content").Array()
if len(parts) != 2 {
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
}
if parts[0].Get("type").String() != "text" || parts[0].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw)
}
if text := parts[0].Get("text").String(); text != "" && text != "You are Qwen Code." {
t.Fatalf("messages[0].content[0].text = %q, want empty string or default prompt", text)
}
if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" {
t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw)
}
if msgs[1].Get("role").String() != "user" {
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
}
}
func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) {
payload := []byte(`{
"messages": [
{ "role": "system", "content": { "type": "text", "text": "ABCDEFG" } },
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
parts := msgs[0].Get("content").Array()
if len(parts) != 2 {
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
}
if parts[1].Get("text").String() != "ABCDEFG" {
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG")
}
}
func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) {
payload := []byte(`{
"messages": [
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
if msgs[0].Get("role").String() != "system" {
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
}
if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 {
t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw)
}
if msgs[1].Get("role").String() != "user" {
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
}
}
func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) {
payload := []byte(`{
"messages": [
{ "role": "system", "content": "A" },
{ "role": "user", "content": [ { "type": "text", "text": "hi" } ] },
{ "role": "system", "content": "B" }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
parts := msgs[0].Get("content").Array()
if len(parts) != 3 {
t.Fatalf("messages[0].content length = %d, want 3", len(parts))
}
if parts[1].Get("text").String() != "A" {
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A")
}
if parts[2].Get("text").String() != "B" {
t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B")
}
}
func TestWrapQwenError_InsufficientQuotaDoesNotSetRetryAfter(t *testing.T) {
body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`)
code, retryAfter := wrapQwenError(context.Background(), http.StatusTooManyRequests, body)
if code != http.StatusTooManyRequests {
t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests)
}
if retryAfter != nil {
t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter)
}
}
func TestWrapQwenError_Maps403QuotaTo429WithoutRetryAfter(t *testing.T) {
body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`)
code, retryAfter := wrapQwenError(context.Background(), http.StatusForbidden, body)
if code != http.StatusTooManyRequests {
t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests)
}
if retryAfter != nil {
t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter)
}
}
func TestQwenCreds_NormalizesResourceURL(t *testing.T) {
tests := []struct {
name string
resourceURL string
wantBaseURL string
}{
{"host only", "portal.qwen.ai", "https://portal.qwen.ai/v1"},
{"scheme no v1", "https://portal.qwen.ai", "https://portal.qwen.ai/v1"},
{"scheme with v1", "https://portal.qwen.ai/v1", "https://portal.qwen.ai/v1"},
{"scheme with v1 slash", "https://portal.qwen.ai/v1/", "https://portal.qwen.ai/v1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"access_token": "test-token",
"resource_url": tt.resourceURL,
},
}
token, baseURL := qwenCreds(auth)
if token != "test-token" {
t.Fatalf("qwenCreds token = %q, want %q", token, "test-token")
}
if baseURL != tt.wantBaseURL {
t.Fatalf("qwenCreds baseURL = %q, want %q", baseURL, tt.wantBaseURL)
}
})
}
}

View File

@@ -32,16 +32,24 @@ type GitTokenStore struct {
repoDir string repoDir string
configDir string configDir string
remote string remote string
branch string
username string username string
password string password string
lastGC time.Time lastGC time.Time
} }
type resolvedRemoteBranch struct {
name plumbing.ReferenceName
hash plumbing.Hash
}
// NewGitTokenStore creates a token store that saves credentials to disk through the // NewGitTokenStore creates a token store that saves credentials to disk through the
// TokenStorage implementation embedded in the token record. // TokenStorage implementation embedded in the token record.
func NewGitTokenStore(remote, username, password string) *GitTokenStore { // When branch is non-empty, clone/pull/push operations target that branch instead of the remote default.
func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore {
return &GitTokenStore{ return &GitTokenStore{
remote: remote, remote: remote,
branch: strings.TrimSpace(branch),
username: username, username: username,
password: password, password: password,
} }
@@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock() s.dirLock.Unlock()
return fmt.Errorf("git token store: create repo dir: %w", errMk) return fmt.Errorf("git token store: create repo dir: %w", errMk)
} }
if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil { cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote}
if s.branch != "" {
cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil {
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) { if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
_ = os.RemoveAll(gitDir) _ = os.RemoveAll(gitDir)
repo, errInit := git.PlainInit(repoDir, false) repo, errInit := git.PlainInit(repoDir, false)
@@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock() s.dirLock.Unlock()
return fmt.Errorf("git token store: init empty repo: %w", errInit) return fmt.Errorf("git token store: init empty repo: %w", errInit)
} }
if s.branch != "" {
headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch))
if errHead := repo.Storer.SetReference(headRef); errHead != nil {
s.dirLock.Unlock()
return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead)
}
}
if _, errRemote := repo.Remote("origin"); errRemote != nil { if _, errRemote := repo.Remote("origin"); errRemote != nil {
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{ if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
Name: "origin", Name: "origin",
@@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock() s.dirLock.Unlock()
return fmt.Errorf("git token store: worktree: %w", errWorktree) return fmt.Errorf("git token store: worktree: %w", errWorktree)
} }
if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil { if s.branch != "" {
if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil {
s.dirLock.Unlock()
return errCheckout
}
} else {
// When branch is unset, ensure the working tree follows the remote default branch
if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil {
if !shouldFallbackToCurrentBranch(repo, err) {
s.dirLock.Unlock()
return fmt.Errorf("git token store: checkout remote default: %w", err)
}
}
}
pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"}
if s.branch != "" {
pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if errPull := worktree.Pull(pullOpts); errPull != nil {
switch { switch {
case errors.Is(errPull, git.NoErrAlreadyUpToDate), case errors.Is(errPull, git.NoErrAlreadyUpToDate),
errors.Is(errPull, git.ErrUnstagedChanges), errors.Is(errPull, git.ErrUnstagedChanges),
errors.Is(errPull, git.ErrNonFastForwardUpdate): errors.Is(errPull, git.ErrNonFastForwardUpdate):
// Ignore clean syncs, local edits, and remote divergence—local changes win. // Ignore clean syncs, local edits, and remote divergence—local changes win.
case errors.Is(errPull, transport.ErrAuthenticationRequired), case errors.Is(errPull, transport.ErrAuthenticationRequired),
errors.Is(errPull, plumbing.ErrReferenceNotFound),
errors.Is(errPull, transport.ErrEmptyRemoteRepository): errors.Is(errPull, transport.ErrEmptyRemoteRepository):
// Ignore authentication prompts and empty remote references on initial sync. // Ignore authentication prompts and empty remote references on initial sync.
case errors.Is(errPull, plumbing.ErrReferenceNotFound):
if s.branch != "" {
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
}
// Ignore missing references only when following the remote default branch.
default: default:
s.dirLock.Unlock() s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull) return fmt.Errorf("git token store: pull: %w", errPull)
@@ -446,6 +488,7 @@ func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
if email, ok := metadata["email"].(string); ok && email != "" { if email, ok := metadata["email"].(string); ok && email != "" {
auth.Attributes["email"] = email auth.Attributes["email"] = email
} }
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
return auth, nil return auth, nil
} }
@@ -553,6 +596,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
return rel, nil return rel, nil
} }
func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
branchRefName := plumbing.NewBranchReferenceName(s.branch)
headRef, errHead := repo.Head()
switch {
case errHead == nil && headRef.Name() == branchRefName:
return nil
case errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound):
return fmt.Errorf("git token store: get head: %w", errHead)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err == nil {
return nil
} else if _, errRef := repo.Reference(branchRefName, true); errRef == nil {
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
} else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) {
return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef)
} else if err := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, branchRefName, authMethod); err != nil {
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
}
return nil
}
func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error {
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch)
remoteRef, err := repo.Reference(remoteRefName, true)
if errors.Is(err, plumbing.ErrReferenceNotFound) {
if errSync := syncRemoteReferences(repo, authMethod); errSync != nil {
return fmt.Errorf("sync remote refs: %w", errSync)
}
remoteRef, err = repo.Reference(remoteRefName, true)
}
if err != nil {
return err
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil {
return err
}
cfg, err := repo.Config()
if err != nil {
return fmt.Errorf("git token store: repo config: %w", err)
}
if _, ok := cfg.Branches[s.branch]; !ok {
cfg.Branches[s.branch] = &config.Branch{Name: s.branch}
}
cfg.Branches[s.branch].Remote = "origin"
cfg.Branches[s.branch].Merge = branchRefName
if err := repo.SetConfig(cfg); err != nil {
return fmt.Errorf("git token store: set branch config: %w", err)
}
return nil
}
func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error {
if err := repo.Fetch(&git.FetchOptions{Auth: authMethod, RemoteName: "origin"}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) {
return err
}
return nil
}
// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch
// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master).
func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) {
if err := syncRemoteReferences(repo, authMethod); err != nil {
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err)
}
remote, err := repo.Remote("origin")
if err != nil {
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err)
}
refs, err := remote.List(&git.ListOptions{Auth: authMethod})
if err != nil {
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
return resolved, nil
}
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err)
}
for _, r := range refs {
if r.Name() == plumbing.HEAD {
if r.Type() == plumbing.SymbolicReference {
if target, ok := normalizeRemoteBranchReference(r.Target()); ok {
return resolvedRemoteBranch{name: target}, nil
}
}
s := r.String()
if idx := strings.Index(s, "->"); idx != -1 {
if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok {
return resolvedRemoteBranch{name: target}, nil
}
}
}
}
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
return resolved, nil
}
for _, r := range refs {
if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok {
return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil
}
}
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found")
}
func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) {
ref, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true)
if err != nil || ref.Type() != plumbing.SymbolicReference {
return resolvedRemoteBranch{}, false
}
target, ok := normalizeRemoteBranchReference(ref.Target())
if !ok {
return resolvedRemoteBranch{}, false
}
return resolvedRemoteBranch{name: target}, true
}
func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) {
switch {
case strings.HasPrefix(name.String(), "refs/heads/"):
return name, true
case strings.HasPrefix(name.String(), "refs/remotes/origin/"):
return plumbing.NewBranchReferenceName(strings.TrimPrefix(name.String(), "refs/remotes/origin/")), true
default:
return "", false
}
}
// shouldFallbackToCurrentBranch reports whether a remote-resolution failure may
// be tolerated by keeping the currently checked-out branch. That is allowed only
// for auth-required or empty-remote errors, and only when a local HEAD exists.
func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool {
	recoverable := errors.Is(err, transport.ErrAuthenticationRequired) ||
		errors.Is(err, transport.ErrEmptyRemoteRepository)
	if !recoverable {
		return false
	}
	if _, headErr := repo.Head(); headErr != nil {
		return false
	}
	return true
}
// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's default branch
// (the branch target of origin/HEAD). If the local branch does not exist it will be created to track
// the remote branch.
func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
	resolved, err := resolveRemoteDefaultBranch(repo, authMethod)
	if err != nil {
		return err
	}
	branchRefName := resolved.name
	// If HEAD already points to the desired branch, nothing to do.
	headRef, errHead := repo.Head()
	if errHead == nil && headRef.Name() == branchRefName {
		return nil
	}
	// If local branch exists, attempt a checkout
	if _, err := repo.Reference(branchRefName, true); err == nil {
		if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err != nil {
			return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), err)
		}
		return nil
	}
	// Try to find the corresponding remote tracking ref (refs/remotes/origin/<name>)
	branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/")
	remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort)
	// Prefer the locally cached remote-tracking hash; fall back to the hash
	// captured during remote resolution when the tracking ref is absent.
	hash := resolved.hash
	if remoteRef, err := repo.Reference(remoteRefName, true); err == nil {
		hash = remoteRef.Hash()
	} else if err != nil && !errors.Is(err, plumbing.ErrReferenceNotFound) {
		return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), err)
	}
	if hash == plumbing.ZeroHash {
		// Neither the tracking ref nor the remote resolution produced a commit
		// to branch from, so the branch cannot be created.
		return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String())
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil {
		return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err)
	}
	// Record upstream tracking so subsequent pulls/pushes of the newly created
	// branch target origin/<branchShort>.
	cfg, err := repo.Config()
	if err != nil {
		return fmt.Errorf("git token store: repo config: %w", err)
	}
	if _, ok := cfg.Branches[branchShort]; !ok {
		cfg.Branches[branchShort] = &config.Branch{Name: branchShort}
	}
	cfg.Branches[branchShort].Remote = "origin"
	cfg.Branches[branchShort].Merge = branchRefName
	if err := repo.SetConfig(cfg); err != nil {
		return fmt.Errorf("git token store: set branch config: %w", err)
	}
	return nil
}
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error { func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
repoDir := s.repoDirSnapshot() repoDir := s.repoDirSnapshot()
if repoDir == "" { if repoDir == "" {
@@ -618,7 +847,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
return errRewrite return errRewrite
} }
s.maybeRunGC(repo) s.maybeRunGC(repo)
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil { pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true}
if s.branch != "" {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)}
} else {
// When branch is unset, pin push to the currently checked-out branch.
if headRef, err := repo.Head(); err == nil {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())}
}
}
if err = repo.Push(pushOpts); err != nil {
if errors.Is(err, git.NoErrAlreadyUpToDate) { if errors.Is(err, git.NoErrAlreadyUpToDate) {
return nil return nil
} }

View File

@@ -0,0 +1,585 @@
package store
import (
	"errors"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/go-git/go-git/v6"
	gitconfig "github.com/go-git/go-git/v6/config"
	"github.com/go-git/go-git/v6/plumbing"
	"github.com/go-git/go-git/v6/plumbing/object"
)
// testBranchSpec describes one branch seeded into a fixture remote: the branch
// name and the contents written to that branch's branch.txt marker file.
type testBranchSpec struct {
	name     string
	contents string
}
// TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured verifies
// that with no branch configured, EnsureRepository clones the remote's default
// branch (the origin/HEAD target) and pulls its updates on subsequent calls,
// without repointing the remote's HEAD.
func TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
		testBranchSpec{name: "release/2026", contents: "release branch\n"},
	)
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch\n")
	// Advance both branches remotely; only the default branch should be pulled in.
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository second call: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch updated\n")
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet verifies that an
// explicitly configured branch wins over the remote default: the clone and
// later pulls track the configured branch while the remote's HEAD stays put.
func TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
		testBranchSpec{name: "release/2026", contents: "release branch\n"},
	)
	store := NewGitTokenStore(remoteDir, "", "", "release/2026")
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
	// Advance both branches remotely; only the configured branch should be pulled in.
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository second call: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch updated\n")
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch verifies that a
// fresh EnsureRepository fails when the configured branch does not exist on the
// remote, and that the remote's HEAD is left untouched.
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
	)
	store := NewGitTokenStore(remoteDir, "", "", "missing-branch")
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	err := store.EnsureRepository()
	if err == nil {
		t.Fatal("EnsureRepository succeeded, want error for nonexistent configured branch")
	}
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull
// verifies that reopening an already-cloned workspace with a configured branch
// that does not exist on the remote fails, leaving both the local HEAD and the
// remote HEAD unchanged.
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	// Reopen the same workspace with a branch that is absent on the remote.
	reopened := NewGitTokenStore(remoteDir, "", "", "missing-branch")
	reopened.SetBaseDir(baseDir)
	err := reopened.EnsureRepository()
	if err == nil {
		t.Fatal("EnsureRepository succeeded on reopen, want error for nonexistent configured branch")
	}
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "trunk")
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch verifies that
// against an empty bare remote, EnsureRepository creates and pushes the
// configured branch (with a commit) rather than a default "master" branch.
func TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch(t *testing.T) {
	root := t.TempDir()
	remoteDir := filepath.Join(root, "remote.git")
	if _, err := git.PlainInit(remoteDir, true); err != nil {
		t.Fatalf("init bare remote: %v", err)
	}
	branch := "feature/gemini-fix"
	store := NewGitTokenStore(remoteDir, "", "", branch)
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository: %v", err)
	}
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), branch)
	assertRemoteBranchExistsWithCommit(t, remoteDir, branch)
	assertRemoteBranchDoesNotExist(t, remoteDir, "master")
}
// TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch verifies that a
// workspace originally cloned on the remote default switches to the newly
// configured branch on reopen, and that a subsequent commit-and-push updates
// only that branch on the remote.
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "develop", contents: "remote develop branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
	// Reopen with an explicit branch; the workspace must switch to it.
	reopened := NewGitTokenStore(remoteDir, "", "", "develop")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository reopen: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
	workspaceDir := filepath.Join(root, "workspace")
	if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local develop update\n"), 0o600); err != nil {
		t.Fatalf("write local branch marker: %v", err)
	}
	// commitAndPushLocked requires the store mutex to be held by the caller.
	reopened.mu.Lock()
	err := reopened.commitAndPushLocked("Update develop branch marker", "branch.txt")
	reopened.mu.Unlock()
	if err != nil {
		t.Fatalf("commitAndPushLocked: %v", err)
	}
	assertRepositoryHeadBranch(t, workspaceDir, "develop")
	// The push must touch develop only; master keeps its seeded contents.
	assertRemoteBranchContents(t, remoteDir, "develop", "local develop update\n")
	assertRemoteBranchContents(t, remoteDir, "master", "remote master branch\n")
}
// TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone
// verifies that a workspace can switch to a configured branch that was created
// on the remote only after the initial clone.
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
	// Create the target branch on the remote after the clone has happened.
	advanceRemoteBranchFromNewBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch\n", "create release")
	reopened := NewGitTokenStore(remoteDir, "", "", "release/2026")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository reopen: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
}
// TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset verifies that a
// workspace previously pinned to a non-default branch falls back to the
// remote's default branch when reopened with the branch setting cleared, and
// that pushes from the unset-branch store land on that default branch.
func TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "develop", contents: "remote develop branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	// First store pins to develop and prepares local workspace
	storePinned := NewGitTokenStore(remoteDir, "", "", "develop")
	storePinned.SetBaseDir(baseDir)
	if err := storePinned.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository pinned: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
	// Second store has branch unset and should reset local workspace to remote default (master)
	storeDefault := NewGitTokenStore(remoteDir, "", "", "")
	storeDefault.SetBaseDir(baseDir)
	if err := storeDefault.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository default: %v", err)
	}
	// Local HEAD should now follow remote default (master)
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "master")
	// Make a local change and push using the store with branch unset; push should update remote master
	workspaceDir := filepath.Join(root, "workspace")
	if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local master update\n"), 0o600); err != nil {
		t.Fatalf("write local master marker: %v", err)
	}
	// commitAndPushLocked requires the store mutex to be held by the caller.
	storeDefault.mu.Lock()
	if err := storeDefault.commitAndPushLocked("Update master marker", "branch.txt"); err != nil {
		storeDefault.mu.Unlock()
		t.Fatalf("commitAndPushLocked: %v", err)
	}
	storeDefault.mu.Unlock()
	assertRemoteBranchContents(t, remoteDir, "master", "local master update\n")
}
// TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable verifies
// that when the remote's HEAD is repointed to a different branch after the
// initial clone, a reopened store with no configured branch follows the new
// remote default.
func TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "main", contents: "remote main branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
	// Simulate a default-branch rename on the remote (master -> main).
	setRemoteHeadBranch(t, remoteDir, "main")
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "main", "remote main branch updated\n", "advance main")
	reopened := NewGitTokenStore(remoteDir, "", "", "")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository after remote default rename: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "main", "remote main branch updated\n")
	assertRemoteHeadBranch(t, remoteDir, "main")
}
// TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved
// verifies the fallback path: when listing remote refs fails (here: the origin
// URL is swapped to an HTTP server that always returns 401), EnsureRepository
// keeps the currently checked-out branch instead of failing.
func TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "develop", contents: "remote develop branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	pinned := NewGitTokenStore(remoteDir, "", "", "develop")
	pinned.SetBaseDir(baseDir)
	if err := pinned.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository pinned: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
	// Stand in for an unreachable/auth-gated remote: every request gets a 401.
	authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("WWW-Authenticate", `Basic realm="git"`)
		http.Error(w, "auth required", http.StatusUnauthorized)
	}))
	defer authServer.Close()
	// Repoint origin at the failing server in the workspace's git config.
	repo, err := git.PlainOpen(filepath.Join(root, "workspace"))
	if err != nil {
		t.Fatalf("open workspace repo: %v", err)
	}
	cfg, err := repo.Config()
	if err != nil {
		t.Fatalf("read repo config: %v", err)
	}
	cfg.Remotes["origin"].URLs = []string{authServer.URL}
	if err := repo.SetConfig(cfg); err != nil {
		t.Fatalf("set repo config: %v", err)
	}
	reopened := NewGitTokenStore(remoteDir, "", "", "")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository default branch fallback: %v", err)
	}
	// The previously checked-out branch must survive the failed resolution.
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "develop")
}
// setupGitRemoteRepository builds a bare remote at <root>/remote.git seeded from
// a working clone at <root>/seed. Every spec becomes a branch whose branch.txt
// carries the given contents; the remote's HEAD is pointed at defaultBranch.
// Returns the path of the bare remote.
func setupGitRemoteRepository(t *testing.T, root, defaultBranch string, branches ...testBranchSpec) string {
	t.Helper()
	remoteDir := filepath.Join(root, "remote.git")
	if _, err := git.PlainInit(remoteDir, true); err != nil {
		t.Fatalf("init bare remote: %v", err)
	}
	seedDir := filepath.Join(root, "seed")
	seedRepo, err := git.PlainInit(seedDir, false)
	if err != nil {
		t.Fatalf("init seed repo: %v", err)
	}
	// Point the seed repo's HEAD at the requested default branch before the first commit.
	if err := seedRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
		t.Fatalf("set seed HEAD: %v", err)
	}
	worktree, err := seedRepo.Worktree()
	if err != nil {
		t.Fatalf("open seed worktree: %v", err)
	}
	defaultSpec, ok := findBranchSpec(branches, defaultBranch)
	if !ok {
		t.Fatalf("missing default branch spec for %q", defaultBranch)
	}
	commitBranchMarker(t, seedDir, worktree, defaultSpec, "seed default branch")
	// Create each non-default branch off the default branch and commit its marker.
	for _, branch := range branches {
		if branch.name == defaultBranch {
			continue
		}
		if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(defaultBranch)}); err != nil {
			t.Fatalf("checkout default branch %s: %v", defaultBranch, err)
		}
		if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch.name), Create: true}); err != nil {
			t.Fatalf("create branch %s: %v", branch.name, err)
		}
		commitBranchMarker(t, seedDir, worktree, branch, "seed branch "+branch.name)
	}
	if _, err := seedRepo.CreateRemote(&gitconfig.RemoteConfig{Name: "origin", URLs: []string{remoteDir}}); err != nil {
		t.Fatalf("create origin remote: %v", err)
	}
	// Push all seeded branches to the bare remote in one go.
	if err := seedRepo.Push(&git.PushOptions{
		RemoteName: "origin",
		RefSpecs:   []gitconfig.RefSpec{gitconfig.RefSpec("refs/heads/*:refs/heads/*")},
	}); err != nil {
		t.Fatalf("push seed branches: %v", err)
	}
	// Ensure the bare remote's HEAD points at the requested default branch.
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
		t.Fatalf("set remote HEAD: %v", err)
	}
	return remoteDir
}
// commitBranchMarker writes branch.txt with the spec's contents, stages it, and
// commits it with a fixed author and timestamp so fixture commits are deterministic.
func commitBranchMarker(t *testing.T, seedDir string, worktree *git.Worktree, branch testBranchSpec, message string) {
	t.Helper()
	markerPath := filepath.Join(seedDir, "branch.txt")
	if writeErr := os.WriteFile(markerPath, []byte(branch.contents), 0o600); writeErr != nil {
		t.Fatalf("write branch marker for %s: %v", branch.name, writeErr)
	}
	if _, addErr := worktree.Add("branch.txt"); addErr != nil {
		t.Fatalf("add branch marker for %s: %v", branch.name, addErr)
	}
	author := &object.Signature{
		Name:  "CLIProxyAPI",
		Email: "cliproxy@local",
		When:  time.Unix(1711929600, 0),
	}
	if _, commitErr := worktree.Commit(message, &git.CommitOptions{Author: author}); commitErr != nil {
		t.Fatalf("commit branch marker for %s: %v", branch.name, commitErr)
	}
}
// advanceRemoteBranch checks out an existing branch in the seed clone, commits a
// new branch.txt marker with contents, and pushes only that branch to origin.
func advanceRemoteBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
	t.Helper()
	seedRepo, err := git.PlainOpen(seedDir)
	if err != nil {
		t.Fatalf("open seed repo: %v", err)
	}
	worktree, err := seedRepo.Worktree()
	if err != nil {
		t.Fatalf("open seed worktree: %v", err)
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch)}); err != nil {
		t.Fatalf("checkout branch %s: %v", branch, err)
	}
	commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
	// Push the single updated branch; the explicit refspec avoids touching siblings.
	if err := seedRepo.Push(&git.PushOptions{
		RemoteName: "origin",
		RefSpecs: []gitconfig.RefSpec{
			gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
		},
	}); err != nil {
		t.Fatalf("push branch %s update to %s: %v", branch, remoteDir, err)
	}
}
// advanceRemoteBranchFromNewBranch creates a brand-new branch off master in the
// seed clone, commits a branch.txt marker with contents, and pushes only that
// branch to origin. Used to simulate branches created after the initial clone.
func advanceRemoteBranchFromNewBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
	t.Helper()
	seedRepo, err := git.PlainOpen(seedDir)
	if err != nil {
		t.Fatalf("open seed repo: %v", err)
	}
	worktree, err := seedRepo.Worktree()
	if err != nil {
		t.Fatalf("open seed worktree: %v", err)
	}
	// Branch off master so the new branch shares the seeded history.
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName("master")}); err != nil {
		t.Fatalf("checkout master before creating %s: %v", branch, err)
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch), Create: true}); err != nil {
		t.Fatalf("create branch %s: %v", branch, err)
	}
	commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
	// Push the single new branch; the explicit refspec avoids touching siblings.
	if err := seedRepo.Push(&git.PushOptions{
		RemoteName: "origin",
		RefSpecs: []gitconfig.RefSpec{
			gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
		},
	}); err != nil {
		t.Fatalf("push new branch %s update to %s: %v", branch, remoteDir, err)
	}
}
// findBranchSpec returns the spec whose name matches, plus whether one was found.
func findBranchSpec(branches []testBranchSpec, name string) (testBranchSpec, bool) {
	for i := range branches {
		if branches[i].name == name {
			return branches[i], true
		}
	}
	var zero testBranchSpec
	return zero, false
}
// assertRepositoryBranchAndContents fails the test unless the repository at
// repoDir has HEAD on the named branch and branch.txt holds wantContents.
func assertRepositoryBranchAndContents(t *testing.T, repoDir, branch, wantContents string) {
	t.Helper()
	repo, err := git.PlainOpen(repoDir)
	if err != nil {
		t.Fatalf("open local repo: %v", err)
	}
	head, err := repo.Head()
	if err != nil {
		t.Fatalf("local repo head: %v", err)
	}
	want := plumbing.NewBranchReferenceName(branch)
	if head.Name() != want {
		t.Fatalf("local head branch = %s, want %s", head.Name(), want)
	}
	raw, readErr := os.ReadFile(filepath.Join(repoDir, "branch.txt"))
	if readErr != nil {
		t.Fatalf("read branch marker: %v", readErr)
	}
	if got := string(raw); got != wantContents {
		t.Fatalf("branch marker contents = %q, want %q", got, wantContents)
	}
}
// assertRepositoryHeadBranch fails the test unless the repository at repoDir
// has HEAD checked out on the named branch.
func assertRepositoryHeadBranch(t *testing.T, repoDir, branch string) {
	t.Helper()
	repo, openErr := git.PlainOpen(repoDir)
	if openErr != nil {
		t.Fatalf("open local repo: %v", openErr)
	}
	head, headErr := repo.Head()
	if headErr != nil {
		t.Fatalf("local repo head: %v", headErr)
	}
	want := plumbing.NewBranchReferenceName(branch)
	if head.Name() != want {
		t.Fatalf("local head branch = %s, want %s", head.Name(), want)
	}
}
// assertRemoteHeadBranch fails the test unless the bare remote's symbolic HEAD
// reference targets the named branch.
func assertRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	// resolved=false: read HEAD itself, not the branch it points at.
	head, err := remoteRepo.Reference(plumbing.HEAD, false)
	if err != nil {
		t.Fatalf("read remote HEAD: %v", err)
	}
	if got, want := head.Target(), plumbing.NewBranchReferenceName(branch); got != want {
		t.Fatalf("remote HEAD target = %s, want %s", got, want)
	}
}
// setRemoteHeadBranch repoints the bare remote's symbolic HEAD at the given branch.
func setRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, openErr := git.PlainOpen(remoteDir)
	if openErr != nil {
		t.Fatalf("open remote repo: %v", openErr)
	}
	newHead := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(branch))
	if setErr := remoteRepo.Storer.SetReference(newHead); setErr != nil {
		t.Fatalf("set remote HEAD to %s: %v", branch, setErr)
	}
}
// assertRemoteBranchExistsWithCommit fails the test unless the bare remote has
// the named branch pointing at a real (non-zero) commit hash.
func assertRemoteBranchExistsWithCommit(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, openErr := git.PlainOpen(remoteDir)
	if openErr != nil {
		t.Fatalf("open remote repo: %v", openErr)
	}
	ref, refErr := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
	if refErr != nil {
		t.Fatalf("read remote branch %s: %v", branch, refErr)
	}
	if got := ref.Hash(); got == plumbing.ZeroHash {
		t.Fatalf("remote branch %s hash = %s, want non-zero hash", branch, got)
	}
}
// assertRemoteBranchDoesNotExist fails the test unless looking up the named
// branch on the bare remote yields plumbing.ErrReferenceNotFound; both a
// successful lookup and any other error are failures.
func assertRemoteBranchDoesNotExist(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	// Use errors.Is rather than a direct != comparison so wrapped
	// not-found errors are still recognized as the expected outcome.
	if _, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false); err == nil {
		t.Fatalf("remote branch %s exists, want missing", branch)
	} else if !errors.Is(err, plumbing.ErrReferenceNotFound) {
		t.Fatalf("read remote branch %s: %v", branch, err)
	}
}
// assertRemoteBranchContents fails the test unless branch.txt in the tip commit
// of the named branch on the bare remote holds exactly wantContents. It reads
// the file from the commit's tree, so no worktree is needed on the bare repo.
func assertRemoteBranchContents(t *testing.T, remoteDir, branch, wantContents string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
	if err != nil {
		t.Fatalf("read remote branch %s: %v", branch, err)
	}
	commit, err := remoteRepo.CommitObject(ref.Hash())
	if err != nil {
		t.Fatalf("read remote branch %s commit: %v", branch, err)
	}
	tree, err := commit.Tree()
	if err != nil {
		t.Fatalf("read remote branch %s tree: %v", branch, err)
	}
	file, err := tree.File("branch.txt")
	if err != nil {
		t.Fatalf("read remote branch %s file: %v", branch, err)
	}
	contents, err := file.Contents()
	if err != nil {
		t.Fatalf("read remote branch %s contents: %v", branch, err)
	}
	if contents != wantContents {
		t.Fatalf("remote branch %s contents = %q, want %q", branch, contents, wantContents)
	}
}

View File

@@ -595,6 +595,7 @@ func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Aut
LastRefreshedAt: time.Time{}, LastRefreshedAt: time.Time{},
NextRefreshAfter: time.Time{}, NextRefreshAfter: time.Time{},
} }
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
return auth, nil return auth, nil
} }

View File

@@ -310,6 +310,7 @@ func (s *PostgresStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error)
LastRefreshedAt: time.Time{}, LastRefreshedAt: time.Time{},
NextRefreshAfter: time.Time{}, NextRefreshAfter: time.Time{},
} }
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
auths = append(auths, auth) auths = append(auths, auth)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {

View File

@@ -330,32 +330,45 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
} }
// Reorder parts for 'model' role to ensure thinking block is first // Reorder parts for 'model' role:
// 1. Thinking parts first (Antigravity API requirement)
// 2. Regular parts (text, inlineData, etc.)
// 3. FunctionCall parts last
//
// Moving functionCall parts to the end prevents tool_use↔tool_result
// pairing breakage: the Antigravity API internally splits model messages
// at functionCall boundaries. If a text part follows a functionCall, the
// split creates an extra assistant turn between tool_use and tool_result,
// which Claude rejects with "tool_use ids were found without tool_result
// blocks immediately after".
if role == "model" { if role == "model" {
partsResult := gjson.GetBytes(clientContentJSON, "parts") partsResult := gjson.GetBytes(clientContentJSON, "parts")
if partsResult.IsArray() { if partsResult.IsArray() {
parts := partsResult.Array() parts := partsResult.Array()
var thinkingParts []gjson.Result if len(parts) > 1 {
var otherParts []gjson.Result var thinkingParts []gjson.Result
for _, part := range parts { var regularParts []gjson.Result
if part.Get("thought").Bool() { var functionCallParts []gjson.Result
thinkingParts = append(thinkingParts, part) for _, part := range parts {
} else { if part.Get("thought").Bool() {
otherParts = append(otherParts, part) thinkingParts = append(thinkingParts, part)
} } else if part.Get("functionCall").Exists() {
} functionCallParts = append(functionCallParts, part)
if len(thinkingParts) > 0 { } else {
firstPartIsThinking := parts[0].Get("thought").Bool() regularParts = append(regularParts, part)
if !firstPartIsThinking || len(thinkingParts) > 1 {
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
} }
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
} }
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
}
for _, p := range regularParts {
newParts = append(newParts, p.Value())
}
for _, p := range functionCallParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
} }
} }
} }

View File

@@ -361,6 +361,167 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
} }
} }
func TestConvertClaudeRequestToAntigravity_ReorderTextAfterFunctionCall(t *testing.T) {
// Bug: text part after tool_use in an assistant message causes Antigravity
// to split at functionCall boundary, creating an extra assistant turn that
// breaks tool_use↔tool_result adjacency (upstream issue #989).
// Fix: reorder parts so functionCall comes last.
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me check..."},
{
"type": "tool_use",
"id": "call_abc",
"name": "Read",
"input": {"file": "test.go"}
},
{"type": "text", "text": "Reading the file now"}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_abc",
"content": "file content"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 3 {
t.Fatalf("Expected 3 parts, got %d", len(parts))
}
// Text parts should come before functionCall
if parts[0].Get("text").String() != "Let me check..." {
t.Errorf("Expected first text part first, got %s", parts[0].Raw)
}
if parts[1].Get("text").String() != "Reading the file now" {
t.Errorf("Expected second text part second, got %s", parts[1].Raw)
}
if !parts[2].Get("functionCall").Exists() {
t.Errorf("Expected functionCall last, got %s", parts[2].Raw)
}
if parts[2].Get("functionCall.name").String() != "Read" {
t.Errorf("Expected functionCall name 'Read', got '%s'", parts[2].Get("functionCall.name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ReorderParallelFunctionCalls(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Reading both files."},
{
"type": "tool_use",
"id": "call_1",
"name": "Read",
"input": {"file": "a.go"}
},
{"type": "text", "text": "And this one too."},
{
"type": "tool_use",
"id": "call_2",
"name": "Read",
"input": {"file": "b.go"}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 4 {
t.Fatalf("Expected 4 parts, got %d", len(parts))
}
if parts[0].Get("text").String() != "Reading both files." {
t.Errorf("Expected first text, got %s", parts[0].Raw)
}
if parts[1].Get("text").String() != "And this one too." {
t.Errorf("Expected second text, got %s", parts[1].Raw)
}
if parts[2].Get("functionCall.name").String() != "Read" || parts[2].Get("functionCall.id").String() != "call_1" {
t.Errorf("Expected fc1 third, got %s", parts[2].Raw)
}
if parts[3].Get("functionCall.name").String() != "Read" || parts[3].Get("functionCall.id").String() != "call_2" {
t.Errorf("Expected fc2 fourth, got %s", parts[3].Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ReorderThinkingAndTextBeforeFunctionCall(t *testing.T) {
cache.ClearSignatureCache("")
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think about this..."
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Before thinking"},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_xyz",
"name": "Bash",
"input": {"command": "ls"}
},
{"type": "text", "text": "After tool call"}
]
}
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// contents.1 = assistant message (contents.0 = user)
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
if len(parts) != 4 {
t.Fatalf("Expected 4 parts, got %d", len(parts))
}
// Order: thinking → text → text → functionCall
if !parts[0].Get("thought").Bool() {
t.Error("First part should be thinking")
}
if parts[1].Get("functionCall").Exists() || parts[1].Get("thought").Bool() {
t.Errorf("Second part should be text, got %s", parts[1].Raw)
}
if parts[2].Get("functionCall").Exists() || parts[2].Get("thought").Bool() {
t.Errorf("Third part should be text, got %s", parts[2].Raw)
}
if !parts[3].Get("functionCall").Exists() {
t.Errorf("Last part should be functionCall, got %s", parts[3].Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) { func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{ inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620", "model": "claude-3-5-sonnet-20240620",

View File

@@ -26,6 +26,11 @@ type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool HasToolCall bool
BlockIndex int BlockIndex int
HasReceivedArgumentsDelta bool HasReceivedArgumentsDelta bool
HasTextDelta bool
TextBlockOpen bool
ThinkingBlockOpen bool
ThinkingStopPending bool
ThinkingSignature string
} }
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion. // ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
@@ -44,7 +49,7 @@ type ConvertCodexResponseToClaudeParams struct {
// //
// Returns: // Returns:
// - [][]byte: A slice of Claude Code-compatible JSON responses // - [][]byte: A slice of Claude Code-compatible JSON responses
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{ *param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false, HasToolCall: false,
@@ -52,7 +57,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
} }
} }
// log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return [][]byte{} return [][]byte{}
} }
@@ -60,9 +64,18 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output := make([]byte, 0, 512) output := make([]byte, 0, 512)
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
params := (*param).(*ConvertCodexResponseToClaudeParams)
if params.ThinkingBlockOpen && params.ThinkingStopPending {
switch rootResult.Get("type").String() {
case "response.content_part.added", "response.completed":
output = append(output, finalizeCodexThinkingBlock(params)...)
}
}
typeResult := rootResult.Get("type") typeResult := rootResult.Get("type")
typeStr := typeResult.String() typeStr := typeResult.String()
var template []byte var template []byte
if typeStr == "response.created" { if typeStr == "response.created" {
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`) template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String()) template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
@@ -70,43 +83,49 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
} else if typeStr == "response.reasoning_summary_part.added" { } else if typeStr == "response.reasoning_summary_part.added" {
if params.ThinkingBlockOpen && params.ThinkingStopPending {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`) template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.ThinkingBlockOpen = true
params.ThinkingStopPending = false
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.reasoning_summary_text.delta" { } else if typeStr == "response.reasoning_summary_text.delta" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`) template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String()) template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.reasoning_summary_part.done" { } else if typeStr == "response.reasoning_summary_part.done" {
template = []byte(`{"type":"content_block_stop","index":0}`) params.ThinkingStopPending = true
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) if params.ThinkingSignature != "" {
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ output = append(output, finalizeCodexThinkingBlock(params)...)
}
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.content_part.added" { } else if typeStr == "response.content_part.added" {
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`) template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.TextBlockOpen = true
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.output_text.delta" { } else if typeStr == "response.output_text.delta" {
params.HasTextDelta = true
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`) template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String()) template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.content_part.done" { } else if typeStr == "response.content_part.done" {
template = []byte(`{"type":"content_block_stop","index":0}`) template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ params.TextBlockOpen = false
params.BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.completed" { } else if typeStr == "response.completed" {
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`) template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall p := params.HasToolCall
stopReason := rootResult.Get("response.stop_reason").String() stopReason := rootResult.Get("response.stop_reason").String()
if p { if p {
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use") template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
@@ -128,13 +147,13 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemResult := rootResult.Get("item") itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String() itemType := itemResult.Get("type").String()
if itemType == "function_call" { if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true output = append(output, finalizeCodexThinkingBlock(params)...)
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false params.HasToolCall = true
params.HasReceivedArgumentsDelta = false
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`) template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String())) template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{ {
// Restore original tool name if shortened
name := itemResult.Get("name").String() name := itemResult.Get("name").String()
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
if orig, ok := rev[name]; ok { if orig, ok := rev[name]; ok {
@@ -146,37 +165,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if itemType == "reasoning" {
params.ThinkingSignature = itemResult.Get("encrypted_content").String()
if params.ThinkingStopPending {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
} }
} else if typeStr == "response.output_item.done" { } else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item") itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String() itemType := itemResult.Get("type").String()
if itemType == "function_call" { if itemType == "message" {
if params.HasTextDelta {
return [][]byte{output}
}
contentResult := itemResult.Get("content")
if !contentResult.Exists() || !contentResult.IsArray() {
return [][]byte{output}
}
var textBuilder strings.Builder
contentResult.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() != "output_text" {
return true
}
if txt := part.Get("text").String(); txt != "" {
textBuilder.WriteString(txt)
}
return true
})
text := textBuilder.String()
if text == "" {
return [][]byte{output}
}
output = append(output, finalizeCodexThinkingBlock(params)...)
if !params.TextBlockOpen {
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.TextBlockOpen = true
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
}
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.text", text)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
template = []byte(`{"type":"content_block_stop","index":0}`) template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ params.TextBlockOpen = false
params.BlockIndex++
params.HasTextDelta = true
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if itemType == "function_call" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if itemType == "reasoning" {
if signature := itemResult.Get("encrypted_content").String(); signature != "" {
params.ThinkingSignature = signature
}
output = append(output, finalizeCodexThinkingBlock(params)...)
params.ThinkingSignature = ""
} }
} else if typeStr == "response.function_call_arguments.delta" { } else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true params.HasReceivedArgumentsDelta = true
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String()) template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.function_call_arguments.done" { } else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments if !params.HasReceivedArgumentsDelta {
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" { if args := rootResult.Get("arguments").String(); args != "" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", args) template, _ = sjson.SetBytes(template, "delta.partial_json", args)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
@@ -191,15 +258,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible // This function processes the complete Codex response and transforms it into a single Claude Code-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all // JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the Claude Code API format. // the information into a single response that matches the Claude Code API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Codex API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte { func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
@@ -230,6 +288,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
switch item.Get("type").String() { switch item.Get("type").String() {
case "reasoning": case "reasoning":
thinkingBuilder := strings.Builder{} thinkingBuilder := strings.Builder{}
signature := item.Get("encrypted_content").String()
if summary := item.Get("summary"); summary.Exists() { if summary := item.Get("summary"); summary.Exists() {
if summary.IsArray() { if summary.IsArray() {
summary.ForEach(func(_, part gjson.Result) bool { summary.ForEach(func(_, part gjson.Result) bool {
@@ -260,9 +319,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} }
} }
} }
if thinkingBuilder.Len() > 0 { if thinkingBuilder.Len() > 0 || signature != "" {
block := []byte(`{"type":"thinking","thinking":""}`) block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String()) block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
if signature != "" {
block, _ = sjson.SetBytes(block, "signature", signature)
}
out, _ = sjson.SetRawBytes(out, "content.-1", block) out, _ = sjson.SetRawBytes(out, "content.-1", block)
} }
case "message": case "message":
@@ -371,6 +433,30 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
return rev return rev
} }
func ClaudeTokenCount(ctx context.Context, count int64) []byte { func ClaudeTokenCount(_ context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count) return translatorcommon.ClaudeInputTokensJSON(count)
} }
func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte {
if !params.ThinkingBlockOpen {
return nil
}
output := make([]byte, 0, 256)
if params.ThinkingSignature != "" {
signatureDelta := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":""}}`)
signatureDelta, _ = sjson.SetBytes(signatureDelta, "index", params.BlockIndex)
signatureDelta, _ = sjson.SetBytes(signatureDelta, "delta.signature", params.ThinkingSignature)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", signatureDelta, 2)
}
contentBlockStop := []byte(`{"type":"content_block_stop","index":0}`)
contentBlockStop, _ = sjson.SetBytes(contentBlockStop, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", contentBlockStop, 2)
params.BlockIndex++
params.ThinkingBlockOpen = false
params.ThinkingStopPending = false
return output
}

View File

@@ -0,0 +1,319 @@
package claude
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"model\":\"gpt-5\"}}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_123\"}}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
startFound := false
signatureDeltaFound := false
stopFound := false
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
switch data.Get("type").String() {
case "content_block_start":
if data.Get("content_block.type").String() == "thinking" {
startFound = true
if data.Get("content_block.signature").Exists() {
t.Fatalf("thinking start block should NOT have signature field when signature is unknown: %s", line)
}
}
case "content_block_delta":
if data.Get("delta.type").String() == "signature_delta" {
signatureDeltaFound = true
if got := data.Get("delta.signature").String(); got != "enc_sig_123" {
t.Fatalf("unexpected signature delta: %q", got)
}
}
case "content_block_stop":
stopFound = true
}
}
}
if !startFound {
t.Fatal("expected thinking content_block_start event")
}
if !signatureDeltaFound {
t.Fatal("expected signature_delta event for thinking block")
}
if !stopFound {
t.Fatal("expected content_block_stop event for thinking block")
}
}
func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
thinkingStartFound := false
thinkingStopFound := false
signatureDeltaFound := false
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
thinkingStartFound = true
if data.Get("content_block.signature").Exists() {
t.Fatalf("thinking start block should NOT have signature field without encrypted_content: %s", line)
}
}
if data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0 {
thinkingStopFound = true
}
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
signatureDeltaFound = true
}
}
}
if !thinkingStartFound {
t.Fatal("expected thinking content_block_start event")
}
if !thinkingStopFound {
t.Fatal("expected thinking content_block_stop event")
}
if signatureDeltaFound {
t.Fatal("did not expect signature_delta without encrypted_content")
}
}
// TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart
// checks that a pending (done-but-unclosed) thinking block is finalized with a
// content_block_stop before a second summary part opens a new thinking block.
func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	frames := [][]byte{
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
	}
	startCount, stopCount := 0, 0
	for _, frame := range frames {
		for _, out := range ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, frame, &param) {
			for _, row := range strings.Split(string(out), "\n") {
				if !strings.HasPrefix(row, "data: ") {
					continue
				}
				ev := gjson.Parse(strings.TrimPrefix(row, "data: "))
				switch ev.Get("type").String() {
				case "content_block_start":
					if ev.Get("content_block.type").String() == "thinking" {
						startCount++
					}
				case "content_block_stop":
					stopCount++
				}
			}
		}
	}
	if startCount != 2 {
		t.Fatalf("expected 2 thinking block starts, got %d", startCount)
	}
	// Only the first block has been finalized; the second is still open.
	if stopCount != 1 {
		t.Fatalf("expected pending thinking block to be finalized before second start, got %d stops", stopCount)
	}
}
// TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning
// asserts that a signature captured from the reasoning item's added event is
// attached (as a signature_delta) to BOTH thinking blocks of a multipart
// reasoning summary, even when the done event omits encrypted_content.
func TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	frames := [][]byte{
		[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_multipart\"}}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Second part\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
	}
	signatureDeltaCount := 0
	for _, frame := range frames {
		for _, out := range ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, frame, &param) {
			for _, row := range strings.Split(string(out), "\n") {
				if !strings.HasPrefix(row, "data: ") {
					continue
				}
				ev := gjson.Parse(strings.TrimPrefix(row, "data: "))
				if ev.Get("type").String() != "content_block_delta" || ev.Get("delta.type").String() != "signature_delta" {
					continue
				}
				signatureDeltaCount++
				if got := ev.Get("delta.signature").String(); got != "enc_sig_multipart" {
					t.Fatalf("unexpected signature delta: %q", got)
				}
			}
		}
	}
	if signatureDeltaCount != 2 {
		t.Fatalf("expected signature_delta for both multipart thinking blocks, got %d", signatureDeltaCount)
	}
}
// TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt
// verifies the converter falls back to the signature captured on
// response.output_item.added when the matching done event carries none.
func TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	frames := [][]byte{
		[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_early\"}}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
	}
	signatureDeltaCount := 0
	for _, frame := range frames {
		for _, out := range ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, frame, &param) {
			for _, row := range strings.Split(string(out), "\n") {
				if !strings.HasPrefix(row, "data: ") {
					continue
				}
				ev := gjson.Parse(strings.TrimPrefix(row, "data: "))
				if ev.Get("type").String() != "content_block_delta" || ev.Get("delta.type").String() != "signature_delta" {
					continue
				}
				signatureDeltaCount++
				if got := ev.Get("delta.signature").String(); got != "enc_sig_early" {
					t.Fatalf("unexpected signature delta: %q", got)
				}
			}
		}
	}
	if signatureDeltaCount != 1 {
		t.Fatalf("expected signature_delta from early-captured signature, got %d", signatureDeltaCount)
	}
}
// TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature checks
// the non-streaming converter: a reasoning item with encrypted_content must
// become the first content block, of type "thinking", carrying both the
// summary text and the preserved signature.
func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	// Complete Codex response: one signed reasoning item followed by the
	// assistant's final message.
	response := []byte(`{
		"type":"response.completed",
		"response":{
			"id":"resp_123",
			"model":"gpt-5",
			"usage":{"input_tokens":10,"output_tokens":20},
			"output":[
				{
					"type":"reasoning",
					"encrypted_content":"enc_sig_nonstream",
					"summary":[{"type":"summary_text","text":"internal reasoning"}]
				},
				{
					"type":"message",
					"content":[{"type":"output_text","text":"final answer"}]
				}
			]
		}
	}`)
	out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil)
	block := gjson.ParseBytes(out).Get("content.0")
	if block.Get("type").String() != "thinking" {
		t.Fatalf("expected first content block to be thinking, got %s", block.Raw)
	}
	if got := block.Get("signature").String(); got != "enc_sig_nonstream" {
		t.Fatalf("expected signature to be preserved, got %q", got)
	}
	if got := block.Get("thinking").String(); got != "internal reasoning" {
		t.Fatalf("unexpected thinking text: %q", got)
	}
}
// TestConvertCodexResponseToClaude_StreamEmptyOutputUsesOutputItemDoneMessageFallback
// ensures that when no output_text deltas were streamed, the converter falls
// back to the message content carried by response.output_item.done and emits
// it as a text_delta.
func TestConvertCodexResponseToClaude_StreamEmptyOutputUsesOutputItemDoneMessageFallback(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"tools":[]}`)
	var param any
	frames := [][]byte{
		[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5\"}}"),
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}"),
		[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
	}
	var outputs [][]byte
	for _, frame := range frames {
		outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, frame, &param)...)
	}
	foundText := false
scan:
	for _, out := range outputs {
		for _, row := range strings.Split(string(out), "\n") {
			if !strings.HasPrefix(row, "data: ") {
				continue
			}
			ev := gjson.Parse(strings.TrimPrefix(row, "data: "))
			if ev.Get("type").String() == "content_block_delta" && ev.Get("delta.type").String() == "text_delta" && ev.Get("delta.text").String() == "ok" {
				foundText = true
				break scan
			}
		}
	}
	if !foundText {
		t.Fatalf("expected fallback content from response.output_item.done message; outputs=%q", outputs)
	}
}

View File

@@ -20,10 +20,11 @@ var (
// ConvertCodexResponseToGeminiParams holds parameters for response conversion. // ConvertCodexResponseToGeminiParams holds parameters for response conversion.
type ConvertCodexResponseToGeminiParams struct { type ConvertCodexResponseToGeminiParams struct {
Model string Model string
CreatedAt int64 CreatedAt int64
ResponseID string ResponseID string
LastStorageOutput []byte LastStorageOutput []byte
HasOutputTextDelta bool
} }
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format. // ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
@@ -42,10 +43,11 @@ type ConvertCodexResponseToGeminiParams struct {
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &ConvertCodexResponseToGeminiParams{ *param = &ConvertCodexResponseToGeminiParams{
Model: modelName, Model: modelName,
CreatedAt: 0, CreatedAt: 0,
ResponseID: "", ResponseID: "",
LastStorageOutput: nil, LastStorageOutput: nil,
HasOutputTextDelta: false,
} }
} }
@@ -58,18 +60,18 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
typeResult := rootResult.Get("type") typeResult := rootResult.Get("type")
typeStr := typeResult.String() typeStr := typeResult.String()
params := (*param).(*ConvertCodexResponseToGeminiParams)
// Base Gemini response template // Base Gemini response template
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`) template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" { {
template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...) template, _ = sjson.SetBytes(template, "modelVersion", params.Model)
} else {
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
createdAtResult := rootResult.Get("response.created_at") createdAtResult := rootResult.Get("response.created_at")
if createdAtResult.Exists() { if createdAtResult.Exists() {
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int() params.CreatedAt = createdAtResult.Int()
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano)) template, _ = sjson.SetBytes(template, "createTime", time.Unix(params.CreatedAt, 0).Format(time.RFC3339Nano))
} }
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID) template, _ = sjson.SetBytes(template, "responseId", params.ResponseID)
} }
// Handle function call completion // Handle function call completion
@@ -101,7 +103,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...) params.LastStorageOutput = append([]byte(nil), template...)
// Use this return to storage message // Use this return to storage message
return [][]byte{} return [][]byte{}
@@ -111,15 +113,45 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if typeStr == "response.created" { // Handle response creation - set model and response ID if typeStr == "response.created" { // Handle response creation - set model and response ID
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String()) template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String()) template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String() params.ResponseID = rootResult.Get("response.id").String()
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta } else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
part := []byte(`{"thought":true,"text":""}`) part := []byte(`{"thought":true,"text":""}`)
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String()) part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta } else if typeStr == "response.output_text.delta" { // Handle regular text content delta
params.HasOutputTextDelta = true
part := []byte(`{"text":""}`) part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String()) part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.output_item.done" { // Fallback: emit final message text when no delta chunks were received
itemResult := rootResult.Get("item")
if itemResult.Get("type").String() != "message" || params.HasOutputTextDelta {
return [][]byte{}
}
contentResult := itemResult.Get("content")
if !contentResult.Exists() || !contentResult.IsArray() {
return [][]byte{}
}
wroteText := false
contentResult.ForEach(func(_, partResult gjson.Result) bool {
if partResult.Get("type").String() != "output_text" {
return true
}
text := partResult.Get("text").String()
if text == "" {
return true
}
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", text)
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
wroteText = true
return true
})
if wroteText {
params.HasOutputTextDelta = true
return [][]byte{template}
}
return [][]byte{}
} else if typeStr == "response.completed" { // Handle response completion with usage metadata } else if typeStr == "response.completed" { // Handle response completion with usage metadata
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int()) template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int()) template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
@@ -129,11 +161,10 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
return [][]byte{} return [][]byte{}
} }
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 { if len(params.LastStorageOutput) > 0 {
return [][]byte{ stored := append([]byte(nil), params.LastStorageOutput...)
append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...), params.LastStorageOutput = nil
template, return [][]byte{stored, template}
}
} }
return [][]byte{template} return [][]byte{template}
} }

View File

@@ -0,0 +1,35 @@
package gemini
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertCodexResponseToGemini_StreamEmptyOutputUsesOutputItemDoneMessageFallback(t *testing.T) {
	// Even with no output_text.delta events in the stream, the message payload
	// inside response.output_item.done must produce a Gemini text part.
	ctx := context.Background()
	originalRequest := []byte(`{"tools":[]}`)
	chunks := [][]byte{
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}"),
		[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
	}
	var param any
	var emitted [][]byte
	for _, raw := range chunks {
		emitted = append(emitted, ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, raw, &param)...)
	}
	for _, payload := range emitted {
		if gjson.GetBytes(payload, "candidates.0.content.parts.0.text").String() == "ok" {
			return // fallback text part was emitted; test passes
		}
	}
	t.Fatalf("expected fallback content from response.output_item.done message; outputs=%q", emitted)
}

View File

@@ -284,12 +284,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
} }
// Process the output array for content and function calls // Process the output array for content and function calls
var toolCalls [][]byte
outputResult := responseResult.Get("output") outputResult := responseResult.Get("output")
if outputResult.IsArray() { if outputResult.IsArray() {
outputArray := outputResult.Array() outputArray := outputResult.Array()
var contentText string var contentText string
var reasoningText string var reasoningText string
var toolCalls [][]byte
for _, outputItem := range outputArray { for _, outputItem := range outputArray {
outputType := outputItem.Get("type").String() outputType := outputItem.Get("type").String()
@@ -367,8 +367,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if statusResult := responseResult.Get("status"); statusResult.Exists() { if statusResult := responseResult.Get("status"); statusResult.Exists() {
status := statusResult.String() status := statusResult.String()
if status == "completed" { if status == "completed" {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop") finishReason := "stop"
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop") if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
} }
} }

View File

@@ -6,7 +6,7 @@
package claude package claude
import ( import (
"bytes" "fmt"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
@@ -31,8 +31,6 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
// - []byte: The transformed request in Gemini CLI format. // - []byte: The transformed request in Gemini CLI format.
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON // Build output Gemini CLI request JSON
out := []byte(`{"contents":[]}`) out := []byte(`{"contents":[]}`)
out, _ = sjson.SetBytes(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
@@ -146,13 +144,37 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}) })
} }
// strip trailing model turn with unanswered function calls —
// Gemini returns empty responses when the last turn is a model
// functionCall with no corresponding user functionResponse.
contents := gjson.GetBytes(out, "contents")
if contents.Exists() && contents.IsArray() {
arr := contents.Array()
if len(arr) > 0 {
last := arr[len(arr)-1]
if last.Get("role").String() == "model" {
hasFC := false
last.Get("parts").ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
hasFC = true
return false
}
return true
})
if hasFC {
out, _ = sjson.DeleteBytes(out, fmt.Sprintf("contents.%d", len(arr)-1))
}
}
}
}
// tools // tools
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() { if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
hasTools := false hasTools := false
toolsResult.ForEach(func(_, toolResult gjson.Result) bool { toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
tool := []byte(toolResult.Raw) tool := []byte(toolResult.Raw)
var err error var err error
tool, err = sjson.DeleteBytes(tool, "input_schema") tool, err = sjson.DeleteBytes(tool, "input_schema")
@@ -168,6 +190,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
tool, _ = sjson.DeleteBytes(tool, "type") tool, _ = sjson.DeleteBytes(tool, "type")
tool, _ = sjson.DeleteBytes(tool, "cache_control") tool, _ = sjson.DeleteBytes(tool, "cache_control")
tool, _ = sjson.DeleteBytes(tool, "defer_loading") tool, _ = sjson.DeleteBytes(tool, "defer_loading")
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String())) tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() { if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
if !hasTools { if !hasTools {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"sort"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -16,29 +17,35 @@ import (
type oaiToResponsesStateReasoning struct { type oaiToResponsesStateReasoning struct {
ReasoningID string ReasoningID string
ReasoningData string ReasoningData string
OutputIndex int
} }
type oaiToResponsesState struct { type oaiToResponsesState struct {
Seq int Seq int
ResponseID string ResponseID string
Created int64 Created int64
Started bool Started bool
ReasoningID string CompletionPending bool
ReasoningIndex int CompletedEmitted bool
ReasoningID string
ReasoningIndex int
// aggregation buffers for response.output // aggregation buffers for response.output
// Per-output message text buffers by index // Per-output message text buffers by index
MsgTextBuf map[int]*strings.Builder MsgTextBuf map[int]*strings.Builder
ReasoningBuf strings.Builder ReasoningBuf strings.Builder
Reasonings []oaiToResponsesStateReasoning Reasonings []oaiToResponsesStateReasoning
FuncArgsBuf map[int]*strings.Builder // index -> args FuncArgsBuf map[string]*strings.Builder
FuncNames map[int]string // index -> name FuncNames map[string]string
FuncCallIDs map[int]string // index -> call_id FuncCallIDs map[string]string
FuncOutputIx map[string]int
MsgOutputIx map[int]int
NextOutputIx int
// message item state per output index // message item state per output index
MsgItemAdded map[int]bool // whether response.output_item.added emitted for message MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
MsgContentAdded map[int]bool // whether response.content_part.added emitted for message MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
MsgItemDone map[int]bool // whether message done events were emitted MsgItemDone map[int]bool // whether message done events were emitted
// function item done state // function item done state
FuncArgsDone map[int]bool FuncArgsDone map[string]bool
FuncItemDone map[int]bool FuncItemDone map[string]bool
// usage aggregation // usage aggregation
PromptTokens int64 PromptTokens int64
CachedTokens int64 CachedTokens int64
@@ -55,20 +62,157 @@ func emitRespEvent(event string, payload []byte) []byte {
return translatorcommon.SSEEventData(event, payload) return translatorcommon.SSEEventData(event, payload)
} }
func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte {
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
outputsWrapper := []byte(`{"arr":[]}`)
type completedOutputItem struct {
index int
raw []byte
}
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
}
}
if len(st.MsgItemAdded) > 0 {
for i := range st.MsgItemAdded {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.SetBytes(item, "content.0.text", txt)
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
}
}
if len(st.FuncArgsBuf) > 0 {
for key := range st.FuncArgsBuf {
args := ""
if b := st.FuncArgsBuf[key]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[key]
name := st.FuncNames[key]
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
}
}
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
for _, item := range outputItems {
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
}
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
if st.UsageSeen {
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
if st.ReasoningTokens > 0 {
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
}
total := st.TotalTokens
if total == 0 {
total = st.PromptTokens + st.CompletionTokens
}
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
return emitRespEvent("response.completed", completed)
}
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks // ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*). // to OpenAI Responses SSE events (response.*).
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &oaiToResponsesState{ *param = &oaiToResponsesState{
FuncArgsBuf: make(map[int]*strings.Builder), FuncArgsBuf: make(map[string]*strings.Builder),
FuncNames: make(map[int]string), FuncNames: make(map[string]string),
FuncCallIDs: make(map[int]string), FuncCallIDs: make(map[string]string),
FuncOutputIx: make(map[string]int),
MsgOutputIx: make(map[int]int),
MsgTextBuf: make(map[int]*strings.Builder), MsgTextBuf: make(map[int]*strings.Builder),
MsgItemAdded: make(map[int]bool), MsgItemAdded: make(map[int]bool),
MsgContentAdded: make(map[int]bool), MsgContentAdded: make(map[int]bool),
MsgItemDone: make(map[int]bool), MsgItemDone: make(map[int]bool),
FuncArgsDone: make(map[int]bool), FuncArgsDone: make(map[string]bool),
FuncItemDone: make(map[int]bool), FuncItemDone: make(map[string]bool),
Reasonings: make([]oaiToResponsesStateReasoning, 0), Reasonings: make([]oaiToResponsesStateReasoning, 0),
} }
} }
@@ -83,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
return [][]byte{} return [][]byte{}
} }
if bytes.Equal(rawJSON, []byte("[DONE]")) { if bytes.Equal(rawJSON, []byte("[DONE]")) {
if st.CompletionPending && !st.CompletedEmitted {
st.CompletedEmitted = true
return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })}
}
return [][]byte{} return [][]byte{}
} }
@@ -125,6 +273,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
} }
nextSeq := func() int { st.Seq++; return st.Seq } nextSeq := func() int { st.Seq++; return st.Seq }
allocOutputIndex := func() int {
ix := st.NextOutputIx
st.NextOutputIx++
return ix
}
toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
var out [][]byte var out [][]byte
if !st.Started { if !st.Started {
@@ -135,20 +289,25 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.ReasoningBuf.Reset() st.ReasoningBuf.Reset()
st.ReasoningID = "" st.ReasoningID = ""
st.ReasoningIndex = 0 st.ReasoningIndex = 0
st.FuncArgsBuf = make(map[int]*strings.Builder) st.FuncArgsBuf = make(map[string]*strings.Builder)
st.FuncNames = make(map[int]string) st.FuncNames = make(map[string]string)
st.FuncCallIDs = make(map[int]string) st.FuncCallIDs = make(map[string]string)
st.FuncOutputIx = make(map[string]int)
st.MsgOutputIx = make(map[int]int)
st.NextOutputIx = 0
st.MsgItemAdded = make(map[int]bool) st.MsgItemAdded = make(map[int]bool)
st.MsgContentAdded = make(map[int]bool) st.MsgContentAdded = make(map[int]bool)
st.MsgItemDone = make(map[int]bool) st.MsgItemDone = make(map[int]bool)
st.FuncArgsDone = make(map[int]bool) st.FuncArgsDone = make(map[string]bool)
st.FuncItemDone = make(map[int]bool) st.FuncItemDone = make(map[string]bool)
st.PromptTokens = 0 st.PromptTokens = 0
st.CachedTokens = 0 st.CachedTokens = 0
st.CompletionTokens = 0 st.CompletionTokens = 0
st.TotalTokens = 0 st.TotalTokens = 0
st.ReasoningTokens = 0 st.ReasoningTokens = 0
st.UsageSeen = false st.UsageSeen = false
st.CompletionPending = false
st.CompletedEmitted = false
// response.created // response.created
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`) created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq()) created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
@@ -185,7 +344,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text) outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
out = append(out, emitRespEvent("response.output_item.done", outputItemDone)) out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text}) st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
st.ReasoningID = "" st.ReasoningID = ""
} }
@@ -201,10 +360,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
stopReasoning(st.ReasoningBuf.String()) stopReasoning(st.ReasoningBuf.String())
st.ReasoningBuf.Reset() st.ReasoningBuf.Reset()
} }
if _, exists := st.MsgOutputIx[idx]; !exists {
st.MsgOutputIx[idx] = allocOutputIndex()
}
msgOutputIndex := st.MsgOutputIx[idx]
if !st.MsgItemAdded[idx] { if !st.MsgItemAdded[idx] {
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`) item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq()) item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx) item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
out = append(out, emitRespEvent("response.output_item.added", item)) out = append(out, emitRespEvent("response.output_item.added", item))
st.MsgItemAdded[idx] = true st.MsgItemAdded[idx] = true
@@ -213,7 +376,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`) part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq()) part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
part, _ = sjson.SetBytes(part, "output_index", idx) part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
part, _ = sjson.SetBytes(part, "content_index", 0) part, _ = sjson.SetBytes(part, "content_index", 0)
out = append(out, emitRespEvent("response.content_part.added", part)) out = append(out, emitRespEvent("response.content_part.added", part))
st.MsgContentAdded[idx] = true st.MsgContentAdded[idx] = true
@@ -222,7 +385,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`) msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq()) msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
msg, _ = sjson.SetBytes(msg, "output_index", idx) msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
msg, _ = sjson.SetBytes(msg, "content_index", 0) msg, _ = sjson.SetBytes(msg, "content_index", 0)
msg, _ = sjson.SetBytes(msg, "delta", c.String()) msg, _ = sjson.SetBytes(msg, "delta", c.String())
out = append(out, emitRespEvent("response.output_text.delta", msg)) out = append(out, emitRespEvent("response.output_text.delta", msg))
@@ -238,10 +401,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
// On first appearance, add reasoning item and part // On first appearance, add reasoning item and part
if st.ReasoningID == "" { if st.ReasoningID == "" {
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx) st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
st.ReasoningIndex = idx st.ReasoningIndex = allocOutputIndex()
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`) item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq()) item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx) item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID) item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
out = append(out, emitRespEvent("response.output_item.added", item)) out = append(out, emitRespEvent("response.output_item.added", item))
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`) part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
@@ -269,6 +432,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
// Before emitting any function events, if a message is open for this index, // Before emitting any function events, if a message is open for this index,
// close its text/content to match Codex expected ordering. // close its text/content to match Codex expected ordering.
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] { if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
msgOutputIndex := st.MsgOutputIx[idx]
fullText := "" fullText := ""
if b := st.MsgTextBuf[idx]; b != nil { if b := st.MsgTextBuf[idx]; b != nil {
fullText = b.String() fullText = b.String()
@@ -276,7 +440,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`) done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq()) done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
done, _ = sjson.SetBytes(done, "output_index", idx) done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
done, _ = sjson.SetBytes(done, "content_index", 0) done, _ = sjson.SetBytes(done, "content_index", 0)
done, _ = sjson.SetBytes(done, "text", fullText) done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitRespEvent("response.output_text.done", done)) out = append(out, emitRespEvent("response.output_text.done", done))
@@ -284,74 +448,78 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`) partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq()) partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
partDone, _ = sjson.SetBytes(partDone, "output_index", idx) partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
partDone, _ = sjson.SetBytes(partDone, "content_index", 0) partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText) partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitRespEvent("response.content_part.done", partDone)) out = append(out, emitRespEvent("response.content_part.done", partDone))
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`) itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq()) itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx) itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText) itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
out = append(out, emitRespEvent("response.output_item.done", itemDone)) out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.MsgItemDone[idx] = true st.MsgItemDone[idx] = true
} }
// Only emit item.added once per tool call and preserve call_id across chunks. tcs.ForEach(func(_, tc gjson.Result) bool {
newCallID := tcs.Get("0.id").String() toolIndex := int(tc.Get("index").Int())
nameChunk := tcs.Get("0.function.name").String() key := toolStateKey(idx, toolIndex)
if nameChunk != "" { newCallID := tc.Get("id").String()
st.FuncNames[idx] = nameChunk nameChunk := tc.Get("function.name").String()
} if nameChunk != "" {
existingCallID := st.FuncCallIDs[idx] st.FuncNames[key] = nameChunk
effectiveCallID := existingCallID
shouldEmitItem := false
if existingCallID == "" && newCallID != "" {
// First time seeing a valid call_id for this index
effectiveCallID = newCallID
st.FuncCallIDs[idx] = newCallID
shouldEmitItem = true
}
if shouldEmitItem && effectiveCallID != "" {
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
o, _ = sjson.SetBytes(o, "output_index", idx)
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
name := st.FuncNames[idx]
o, _ = sjson.SetBytes(o, "item.name", name)
out = append(out, emitRespEvent("response.output_item.added", o))
}
// Ensure args buffer exists for this index
if st.FuncArgsBuf[idx] == nil {
st.FuncArgsBuf[idx] = &strings.Builder{}
}
// Append arguments delta if available and we have a valid call_id to reference
if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
// Prefer an already known call_id; fall back to newCallID if first time
refCallID := st.FuncCallIDs[idx]
if refCallID == "" {
refCallID = newCallID
} }
if refCallID != "" {
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`) existingCallID := st.FuncCallIDs[key]
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq()) effectiveCallID := existingCallID
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID)) shouldEmitItem := false
ad, _ = sjson.SetBytes(ad, "output_index", idx) if existingCallID == "" && newCallID != "" {
ad, _ = sjson.SetBytes(ad, "delta", args.String()) effectiveCallID = newCallID
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad)) st.FuncCallIDs[key] = newCallID
st.FuncOutputIx[key] = allocOutputIndex()
shouldEmitItem = true
} }
st.FuncArgsBuf[idx].WriteString(args.String())
} if shouldEmitItem && effectiveCallID != "" {
outputIndex := st.FuncOutputIx[key]
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
o, _ = sjson.SetBytes(o, "output_index", outputIndex)
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
out = append(out, emitRespEvent("response.output_item.added", o))
}
if st.FuncArgsBuf[key] == nil {
st.FuncArgsBuf[key] = &strings.Builder{}
}
if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
refCallID := st.FuncCallIDs[key]
if refCallID == "" {
refCallID = newCallID
}
if refCallID != "" {
outputIndex := st.FuncOutputIx[key]
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
ad, _ = sjson.SetBytes(ad, "delta", args.String())
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
}
st.FuncArgsBuf[key].WriteString(args.String())
}
return true
})
} }
} }
// finish_reason triggers finalization, including text done/content done/item done, // finish_reason triggers item-level finalization. response.completed is
// reasoning done/part.done, function args done/item done, and completed // deferred until the terminal [DONE] marker so late usage-only chunks can
// still populate response.usage.
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" { if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
// Emit message done events for all indices that started a message // Emit message done events for all indices that started a message
if len(st.MsgItemAdded) > 0 { if len(st.MsgItemAdded) > 0 {
@@ -360,15 +528,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
for i := range st.MsgItemAdded { for i := range st.MsgItemAdded {
idxs = append(idxs, i) idxs = append(idxs, i)
} }
for i := 0; i < len(idxs); i++ { sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs { for _, i := range idxs {
if st.MsgItemAdded[i] && !st.MsgItemDone[i] { if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
msgOutputIndex := st.MsgOutputIx[i]
fullText := "" fullText := ""
if b := st.MsgTextBuf[i]; b != nil { if b := st.MsgTextBuf[i]; b != nil {
fullText = b.String() fullText = b.String()
@@ -376,7 +539,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`) done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq()) done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
done, _ = sjson.SetBytes(done, "output_index", i) done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
done, _ = sjson.SetBytes(done, "content_index", 0) done, _ = sjson.SetBytes(done, "content_index", 0)
done, _ = sjson.SetBytes(done, "text", fullText) done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitRespEvent("response.output_text.done", done)) out = append(out, emitRespEvent("response.output_text.done", done))
@@ -384,14 +547,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`) partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq()) partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
partDone, _ = sjson.SetBytes(partDone, "output_index", i) partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
partDone, _ = sjson.SetBytes(partDone, "content_index", 0) partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText) partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitRespEvent("response.content_part.done", partDone)) out = append(out, emitRespEvent("response.content_part.done", partDone))
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`) itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq()) itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i) itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText) itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
out = append(out, emitRespEvent("response.output_item.done", itemDone)) out = append(out, emitRespEvent("response.output_item.done", itemDone))
@@ -407,192 +570,45 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
// Emit function call done events for any active function calls // Emit function call done events for any active function calls
if len(st.FuncCallIDs) > 0 { if len(st.FuncCallIDs) > 0 {
idxs := make([]int, 0, len(st.FuncCallIDs)) keys := make([]string, 0, len(st.FuncCallIDs))
for i := range st.FuncCallIDs { for key := range st.FuncCallIDs {
idxs = append(idxs, i) keys = append(keys, key)
} }
for i := 0; i < len(idxs); i++ { sort.Slice(keys, func(i, j int) bool {
for j := i + 1; j < len(idxs); j++ { left := st.FuncOutputIx[keys[i]]
if idxs[j] < idxs[i] { right := st.FuncOutputIx[keys[j]]
idxs[i], idxs[j] = idxs[j], idxs[i] return left < right || (left == right && keys[i] < keys[j])
} })
} for _, key := range keys {
} callID := st.FuncCallIDs[key]
for _, i := range idxs { if callID == "" || st.FuncItemDone[key] {
callID := st.FuncCallIDs[i]
if callID == "" || st.FuncItemDone[i] {
continue continue
} }
outputIndex := st.FuncOutputIx[key]
args := "{}" args := "{}"
if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 { if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
args = b.String() args = b.String()
} }
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`) fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq()) fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID)) fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
fcDone, _ = sjson.SetBytes(fcDone, "output_index", i) fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args) fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone)) out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`) itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq()) itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i) itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID)) itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args) itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID) itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i]) itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
out = append(out, emitRespEvent("response.output_item.done", itemDone)) out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.FuncItemDone[i] = true st.FuncItemDone[key] = true
st.FuncArgsDone[i] = true st.FuncArgsDone[key] = true
} }
} }
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`) st.CompletionPending = true
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
// Build response.output using aggregated buffers
outputsWrapper := []byte(`{"arr":[]}`)
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
// Append message items in ascending index order
if len(st.MsgItemAdded) > 0 {
midxs := make([]int, 0, len(st.MsgItemAdded))
for i := range st.MsgItemAdded {
midxs = append(midxs, i)
}
for i := 0; i < len(midxs); i++ {
for j := i + 1; j < len(midxs); j++ {
if midxs[j] < midxs[i] {
midxs[i], midxs[j] = midxs[j], midxs[i]
}
}
}
for _, i := range midxs {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.SetBytes(item, "content.0.text", txt)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf))
for i := range st.FuncArgsBuf {
idxs = append(idxs, i)
}
// small-N sort without extra imports
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs {
args := ""
if b := st.FuncArgsBuf[i]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[i]
name := st.FuncNames[i]
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
if st.UsageSeen {
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
if st.ReasoningTokens > 0 {
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
}
total := st.TotalTokens
if total == 0 {
total = st.PromptTokens + st.CompletionTokens
}
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
out = append(out, emitRespEvent("response.completed", completed))
} }
return true return true

View File

@@ -0,0 +1,423 @@
package responses
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
// parseOpenAIResponsesSSEEvent splits a single SSE chunk of the form
// "event: <name>\ndata: <json>\n..." into its event name and parsed JSON
// payload. It fails the test immediately when the chunk has no data line
// or the data line is not valid JSON.
func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
	t.Helper()
	eventLine, rest, found := strings.Cut(string(chunk), "\n")
	if !found {
		t.Fatalf("unexpected SSE chunk: %q", chunk)
	}
	dataLine, _, _ := strings.Cut(rest, "\n")
	name := strings.TrimSpace(strings.TrimPrefix(eventLine, "event:"))
	payload := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:"))
	if !gjson.Valid(payload) {
		t.Fatalf("invalid SSE data JSON: %q", payload)
	}
	return name, gjson.Parse(payload)
}
// TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone
// verifies that the converter emits exactly one response.completed event and
// defers it until the terminal `data: [DONE]` chunk, so that usage arriving
// after finish_reason (or never arriving at all) is still reflected — or
// omitted — correctly in the final event.
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone(t *testing.T) {
	t.Parallel()
	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
	tests := []struct {
		name           string
		in             []string // upstream chat-completion SSE lines, fed one at a time
		doneInputIndex int      // Index in tt.in where the terminal [DONE] chunk arrives and response.completed must be emitted.
		hasUsage       bool     // whether any upstream chunk carried a usage object
		inputTokens    int64
		outputTokens   int64
		totalTokens    int64
	}{
		{
			// A provider may send finish_reason first and only attach usage in a later chunk (e.g. Vertex AI),
			// so response.completed must wait for [DONE] to include that usage.
			name: "late usage after finish reason",
			in: []string{
				`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_late_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
				`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
				`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":7,"total_tokens":18}}`,
				`data: [DONE]`,
			},
			doneInputIndex: 3,
			hasUsage:       true,
			inputTokens:    11,
			outputTokens:   7,
			totalTokens:    18,
		},
		{
			// When usage arrives on the same chunk as finish_reason, we still expect a
			// single response.completed event and it should remain deferred until [DONE].
			name: "usage on finish reason chunk",
			in: []string{
				`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_usage_same_chunk","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
				`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":13,"completion_tokens":5,"total_tokens":18}}`,
				`data: [DONE]`,
			},
			doneInputIndex: 2,
			hasUsage:       true,
			inputTokens:    13,
			outputTokens:   5,
			totalTokens:    18,
		},
		{
			// An OpenAI-compatible streams from a buggy server might never send usage, so response.completed should
			// still wait for [DONE] but omit the usage object entirely.
			name: "no usage chunk",
			in: []string{
				`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_no_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
				`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
				`data: [DONE]`,
			},
			doneInputIndex: 2,
			hasUsage:       false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			completedCount := 0
			completedInputIndex := -1
			var completedData gjson.Result
			// Reuse converter state across input lines to simulate one streaming response.
			var param any
			for i, line := range tt.in {
				// One upstream chunk can emit multiple downstream SSE events.
				for _, chunk := range ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param) {
					event, data := parseOpenAIResponsesSSEEvent(t, chunk)
					if event != "response.completed" {
						continue
					}
					completedCount++
					completedInputIndex = i
					completedData = data
					if i < tt.doneInputIndex {
						t.Fatalf("unexpected early response.completed on input index %d", i)
					}
				}
			}
			if completedCount != 1 {
				t.Fatalf("expected exactly 1 response.completed event, got %d", completedCount)
			}
			if completedInputIndex != tt.doneInputIndex {
				t.Fatalf("expected response.completed on terminal [DONE] chunk at input index %d, got %d", tt.doneInputIndex, completedInputIndex)
			}
			// Missing upstream usage should stay omitted in the final completed event.
			if !tt.hasUsage {
				if completedData.Get("response.usage").Exists() {
					t.Fatalf("expected response.completed to omit usage when none was provided, got %s", completedData.Get("response.usage").Raw)
				}
				return
			}
			// When usage is present, the final response.completed event must preserve the usage values.
			if got := completedData.Get("response.usage.input_tokens").Int(); got != tt.inputTokens {
				t.Fatalf("unexpected response.usage.input_tokens: got %d want %d", got, tt.inputTokens)
			}
			if got := completedData.Get("response.usage.output_tokens").Int(); got != tt.outputTokens {
				t.Fatalf("unexpected response.usage.output_tokens: got %d want %d", got, tt.outputTokens)
			}
			if got := completedData.Get("response.usage.total_tokens").Int(); got != tt.totalTokens {
				t.Fatalf("unexpected response.usage.total_tokens: got %d want %d", got, tt.totalTokens)
			}
		})
	}
}
// TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate checks
// that two tool calls streamed on the same choice remain fully separate: each gets its own
// added/done event pair, valid (non-concatenated) JSON arguments, and its own function_call item
// in the final response.completed output.
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
	stream := []string{
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
		`data: [DONE]`,
	}
	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
	var convState any
	var emitted [][]byte
	for _, raw := range stream {
		emitted = append(emitted, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(raw), &convState)...)
	}
	// Index everything we care about by call_id so the assertions below stay flat.
	namesAdded := make(map[string]string)
	argsDone := make(map[string]string)
	namesDone := make(map[string]string)
	finalItems := make(map[string]gjson.Result)
	for _, chunk := range emitted {
		eventName, payload := parseOpenAIResponsesSSEEvent(t, chunk)
		switch eventName {
		case "response.output_item.added":
			if payload.Get("item.type").String() == "function_call" {
				namesAdded[payload.Get("item.call_id").String()] = payload.Get("item.name").String()
			}
		case "response.output_item.done":
			if payload.Get("item.type").String() == "function_call" {
				id := payload.Get("item.call_id").String()
				argsDone[id] = payload.Get("item.arguments").String()
				namesDone[id] = payload.Get("item.name").String()
			}
		case "response.completed":
			for _, item := range payload.Get("response.output").Array() {
				if item.Get("type").String() != "function_call" {
					continue
				}
				finalItems[item.Get("call_id").String()] = item
			}
		}
	}
	if len(namesAdded) != 2 {
		t.Fatalf("expected 2 function_call added events, got %d", len(namesAdded))
	}
	if len(argsDone) != 2 {
		t.Fatalf("expected 2 function_call done events, got %d", len(argsDone))
	}
	if namesAdded["call_read"] != "read" {
		t.Fatalf("unexpected added name for call_read: %q", namesAdded["call_read"])
	}
	if namesAdded["call_glob"] != "glob" {
		t.Fatalf("unexpected added name for call_glob: %q", namesAdded["call_glob"])
	}
	if !gjson.Valid(argsDone["call_read"]) {
		t.Fatalf("invalid JSON args for call_read: %q", argsDone["call_read"])
	}
	if !gjson.Valid(argsDone["call_glob"]) {
		t.Fatalf("invalid JSON args for call_glob: %q", argsDone["call_glob"])
	}
	// A "}{" substring would indicate the two calls' argument buffers were merged.
	if strings.Contains(argsDone["call_read"], "}{") {
		t.Fatalf("call_read args were concatenated: %q", argsDone["call_read"])
	}
	if strings.Contains(argsDone["call_glob"], "}{") {
		t.Fatalf("call_glob args were concatenated: %q", argsDone["call_glob"])
	}
	if namesDone["call_read"] != "read" {
		t.Fatalf("unexpected done name for call_read: %q", namesDone["call_read"])
	}
	if namesDone["call_glob"] != "glob" {
		t.Fatalf("unexpected done name for call_glob: %q", namesDone["call_glob"])
	}
	// Escaped backslashes in the SSE payload must round-trip to single backslashes.
	if got := gjson.Get(argsDone["call_read"], "filePath").String(); got != `C:\repo` {
		t.Fatalf("unexpected filePath for call_read: %q", got)
	}
	if got := gjson.Get(argsDone["call_glob"], "path").String(); got != `C:\repo` {
		t.Fatalf("unexpected path for call_glob: %q", got)
	}
	if got := gjson.Get(argsDone["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
		t.Fatalf("unexpected pattern for call_glob: %q", got)
	}
	if len(finalItems) != 2 {
		t.Fatalf("expected 2 function_call items in response.output, got %d", len(finalItems))
	}
	if finalItems["call_read"].Get("name").String() != "read" {
		t.Fatalf("unexpected response.output name for call_read: %q", finalItems["call_read"].Get("name").String())
	}
	if finalItems["call_glob"].Get("name").String() != "glob" {
		t.Fatalf("unexpected response.output name for call_glob: %q", finalItems["call_glob"].Get("name").String())
	}
}
// TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes
// verifies that tool calls arriving on different chat-completion choices are assigned distinct
// output indexes while each keeps its own name and independently valid JSON arguments.
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
	stream := []string{
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
		`data: [DONE]`,
	}
	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
	var convState any
	var emitted [][]byte
	for _, raw := range stream {
		emitted = append(emitted, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(raw), &convState)...)
	}
	// callRecord captures what each added/done event reported for one call_id.
	type callRecord struct {
		outputIndex int64
		name        string
		arguments   string
	}
	added := map[string]callRecord{}
	done := map[string]callRecord{}
	for _, chunk := range emitted {
		eventName, payload := parseOpenAIResponsesSSEEvent(t, chunk)
		isAdded := eventName == "response.output_item.added"
		isDone := eventName == "response.output_item.done"
		if !isAdded && !isDone {
			continue
		}
		if payload.Get("item.type").String() != "function_call" {
			continue
		}
		rec := callRecord{
			outputIndex: payload.Get("output_index").Int(),
			name:        payload.Get("item.name").String(),
		}
		id := payload.Get("item.call_id").String()
		if isAdded {
			added[id] = rec
		} else {
			rec.arguments = payload.Get("item.arguments").String()
			done[id] = rec
		}
	}
	if len(added) != 2 {
		t.Fatalf("expected 2 function_call added events, got %d", len(added))
	}
	if len(done) != 2 {
		t.Fatalf("expected 2 function_call done events, got %d", len(done))
	}
	if added["call_choice0"].name != "glob" {
		t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
	}
	if added["call_choice1"].name != "read" {
		t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
	}
	if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
		t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
	}
	if !gjson.Valid(done["call_choice0"].arguments) {
		t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
	}
	if !gjson.Valid(done["call_choice1"].arguments) {
		t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
	}
	if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
		t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
	}
	if done["call_choice0"].name != "glob" {
		t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
	}
	if done["call_choice1"].name != "read" {
		t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
	}
}
// TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes
// ensures that a plain assistant message on one choice and a tool call on another choice do not
// share an output index in the converted Responses stream.
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
	stream := []string{
		`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
		`data: [DONE]`,
	}
	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
	var convState any
	var emitted [][]byte
	for _, raw := range stream {
		emitted = append(emitted, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(raw), &convState)...)
	}
	// -1 sentinels flag "event never seen".
	msgIdx := int64(-1)
	toolIdx := int64(-1)
	for _, chunk := range emitted {
		eventName, payload := parseOpenAIResponsesSSEEvent(t, chunk)
		if eventName != "response.output_item.added" {
			continue
		}
		itemType := payload.Get("item.type").String()
		if itemType == "message" && payload.Get("item.id").String() == "msg_resp_mixed_0" {
			msgIdx = payload.Get("output_index").Int()
		} else if itemType == "function_call" && payload.Get("item.call_id").String() == "call_choice1" {
			toolIdx = payload.Get("output_index").Int()
		}
	}
	if msgIdx < 0 {
		t.Fatal("did not find message output index")
	}
	if toolIdx < 0 {
		t.Fatal("did not find tool output index")
	}
	if msgIdx == toolIdx {
		t.Fatalf("expected distinct output indexes for message and tool call, both got %d", msgIdx)
	}
}
// TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending
// verifies that sequential tool calls keep ascending output indexes on their done events and appear
// in emission order inside the final response.completed output array.
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
	stream := []string{
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
		`data: [DONE]`,
	}
	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
	var convState any
	var emitted [][]byte
	for _, raw := range stream {
		emitted = append(emitted, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(raw), &convState)...)
	}
	var doneOrder []int64
	var completedIDs []string
	for _, chunk := range emitted {
		eventName, payload := parseOpenAIResponsesSSEEvent(t, chunk)
		if eventName == "response.output_item.done" {
			if payload.Get("item.type").String() == "function_call" {
				doneOrder = append(doneOrder, payload.Get("output_index").Int())
			}
			continue
		}
		if eventName != "response.completed" {
			continue
		}
		for _, item := range payload.Get("response.output").Array() {
			if item.Get("type").String() == "function_call" {
				completedIDs = append(completedIDs, item.Get("call_id").String())
			}
		}
	}
	if len(doneOrder) != 2 {
		t.Fatalf("expected 2 function_call done indexes, got %d", len(doneOrder))
	}
	if doneOrder[0] >= doneOrder[1] {
		t.Fatalf("expected ascending done output indexes, got %v", doneOrder)
	}
	if len(completedIDs) != 2 {
		t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedIDs))
	}
	if completedIDs[0] != "call_glob" || completedIDs[1] != "call_read" {
		t.Fatalf("unexpected completed function_call order: %v", completedIDs)
	}
}

View File

@@ -201,6 +201,7 @@ var zhStrings = map[string]string{
 "usage_output": "输出",
 "usage_cached": "缓存",
 "usage_reasoning": "思考",
+"usage_time": "时间",
 // ── Logs ──
 "logs_title": "📋 日志",
@@ -352,6 +353,7 @@ var enStrings = map[string]string{
 "usage_output": "Output",
 "usage_cached": "Cached",
 "usage_reasoning": "Reasoning",
+"usage_time": "Time",
 // ── Logs ──
 "logs_title": "📋 Logs",

Some files were not shown because too many files have changed in this diff Show More