Compare commits


147 Commits

Author SHA1 Message Date
Luis Pater
03a1bac898 Merge upstream v6.9.9 (PR #483) 2026-04-02 21:31:21 +08:00
Luis Pater
3171d524f0 docs: fix duplicated ProxyPal entry in README files 2026-04-02 21:22:40 +08:00
Luis Pater
3e78a8d500 Merge branch 'main' into dev 2026-04-02 21:21:26 +08:00
Luis Pater
fcba912cc4 Merge pull request #2492 from davidwushi1145/main
fix(responses): reassemble split SSE event/data frames before streaming
2026-04-02 21:20:31 +08:00
Luis Pater
7170eeea5f Merge pull request #2454 from buddingnewinsights/add-proxypal-to-readme
docs: add ProxyPal to "Who is with us?" section
2026-04-02 21:18:22 +08:00
Luis Pater
e3eb048c7a Merge pull request #2489 from Soein/upstream-pr
fix: strengthen Claude reverse-proxy anti-detection capabilities
2026-04-02 21:16:58 +08:00
Luis Pater
a59e92435b Merge pull request #2490 from router-for-me/logs
Refactor websocket logging and error handling
2026-04-02 20:47:31 +08:00
davidwushi1145
108895fc04 Harden Responses SSE framing against partial chunk boundaries
Follow-up review found two real framing hazards in the handler-layer
framer: it could flush a partial `data:` payload before the JSON was
complete, and it could inject an extra newline before chunks that
already began with `\n`/`\r\n`. This commit tightens the framer so it
only emits undelimited events when the buffered `data:` payload is
already valid JSON (or `[DONE]`), skips newline injection for chunks
that already start with a line break, and avoids the heavier
`bytes.Split` path while scanning SSE fields.

The regression suite now covers split `data:` payload chunks,
newline-prefixed chunks, and dropping incomplete trailing data on
flush, so the original Responses fix remains intact while the review
concerns are explicitly locked down.

Constraint: Keep the follow-up limited to handler-layer framing and tests
Rejected: Ignore the review and rely on current executor chunk shapes | leaves partial data payload corruption possible
Rejected: Build a fully generic SSE parser | wider change than needed for the identified risks
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Do not emit undelimited Responses SSE events unless buffered `data:` content is already complete and valid
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers/openai -count=1
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers -count=1
Tested: /tmp/go1.26.1/go/bin/go vet ./sdk/api/handlers/...
Not-tested: Full repository test suite outside sdk/api/handlers packages
2026-04-02 20:39:49 +08:00
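The flush gate described in the commit above can be sketched as follows; `isCompleteDataPayload` is an illustrative name, not the repository's real helper, and the sketch assumes the buffered payload has already had its `data:` prefix removed:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// isCompleteDataPayload reports whether a buffered `data:` payload is safe
// to flush as an undelimited event: it must be the literal [DONE] sentinel
// or a complete, valid JSON document. Partial JSON stays buffered.
func isCompleteDataPayload(payload []byte) bool {
	trimmed := bytes.TrimSpace(payload)
	if bytes.Equal(trimmed, []byte("[DONE]")) {
		return true
	}
	return json.Valid(trimmed)
}

func main() {
	fmt.Println(isCompleteDataPayload([]byte(`{"type":"response.created"}`))) // true
	fmt.Println(isCompleteDataPayload([]byte(`{"type":"resp`)))               // false: split mid-JSON
	fmt.Println(isCompleteDataPayload([]byte(" [DONE] ")))                    // true
}
```

`json.Valid` makes the check cheap without a full decode, which matches the commit's goal of avoiding the heavier `bytes.Split` scanning path.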
davidwushi1145
abc293c642 Prevent malformed Responses SSE frames from breaking stream clients
Line-oriented upstream executors can emit `event:` and `data:` as
separate chunks, but the Responses handler had started terminating
each incoming chunk as a full SSE event. That split `response.created`
into an empty event plus a later data block, which broke downstream
clients like OpenClaw.

This keeps the fix in the handler layer: a small stateful framer now
buffers standalone `event:` lines until the matching `data:` arrives,
preserves already-framed events, and ignores delimiter-only leftovers.
The regression suite now covers split event/data framing, full-event
passthrough, terminal errors, and the bootstrap path that forwards
line-oriented openai-response streams from non-Codex executors too.

Constraint: Keep the fix localized to Responses handler framing instead of patching every executor
Rejected: Revert to v6.9.7 chunk writing | would reintroduce data-only framing regressions
Rejected: Patch each line-oriented executor separately | duplicates fragile SSE assembly logic
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Do not assume incoming Responses stream chunks are already complete SSE events; preserve handler-layer reassembly for split `event:`/`data:` inputs
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers/openai -count=1
Tested: /tmp/go1.26.1/go/bin/go test ./sdk/api/handlers -count=1
Tested: /tmp/go1.26.1/go test ./sdk/api/handlers/... -count=1
Tested: /tmp/go1.26.1/go/bin/go vet ./sdk/api/handlers/...
Tested: Temporary patched server on 127.0.0.1:18317 -> /v1/models 200, /v1/responses non-stream 200, /v1/responses stream emitted combined `event:` + `data:` frames
Not-tested: Full repository test suite outside sdk/api/handlers packages
2026-04-02 20:26:42 +08:00
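A minimal sketch of the stateful framer this commit describes, assuming chunk-per-call input; the type and method names are illustrative, not the handler's actual API:

```go
package main

import (
	"fmt"
	"strings"
)

// sseFramer reassembles split `event:`/`data:` chunks into full SSE events.
// A standalone `event:` line is buffered until its `data:` line arrives;
// already-framed chunks pass through; delimiter-only leftovers are dropped.
type sseFramer struct {
	pendingEvent string // buffered `event:` line awaiting its `data:` line
}

// feed returns a complete SSE event, or "" if the chunk was only buffered.
func (f *sseFramer) feed(chunk string) string {
	trimmed := strings.TrimRight(chunk, "\r\n")
	switch {
	case strings.HasPrefix(trimmed, "event:") && !strings.Contains(trimmed, "\ndata:"):
		f.pendingEvent = trimmed // hold until the matching data: arrives
		return ""
	case strings.HasPrefix(trimmed, "data:") && f.pendingEvent != "":
		event := f.pendingEvent + "\n" + trimmed + "\n\n"
		f.pendingEvent = ""
		return event
	case trimmed == "": // delimiter-only leftover
		return ""
	default: // chunk is already a complete event
		return trimmed + "\n\n"
	}
}

func main() {
	var f sseFramer
	fmt.Printf("%q\n", f.feed("event: response.created\n"))
	fmt.Printf("%q\n", f.feed("data: {\"id\":\"resp_1\"}\n\n"))
}
```

This mirrors the reported failure mode: without buffering, `event: response.created` would be terminated as an empty event and the later `data:` block would arrive orphaned.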
pzy
bb44671845 fix: fix 3 issues in the reverse-proxy anti-detection hardening
- computeFingerprint now uses rune indexing instead of byte indexing, fixing fingerprint mismatches on multi-byte characters
- the utls Chrome TLS fingerprint now applies only to official Anthropic domains; custom base_url values use the standard transport
- IPv6 addresses now use net.JoinHostPort to append the port correctly
2026-04-02 19:12:55 +08:00
Luis Pater
09e480036a feat(auth): add support for managing custom headers in auth files
Closes #2457
2026-04-02 19:11:09 +08:00
pzy
249f969110 fix: use the utls Chrome TLS fingerprint for Claude API requests
The Claude executor's API requests previously used the Go standard library
crypto/tls, whose JA3 fingerprint does not match real Claude Code
(Bun/BoringSSL) and can be identified by Cloudflare.

- Add helps/utls_client.go, wrapping the utls Chrome fingerprint with HTTP/2 and proxy support
- Replace the 4 NewProxyAwareHTTPClient call sites in the Claude executor with NewUtlsHTTPClient
- Other executors (Gemini/Codex/iFlow, etc.) are unaffected and still use standard TLS
- Non-HTTPS requests automatically fall back to the standard transport
2026-04-02 19:09:56 +08:00
hkfires
4f8acec2d8 refactor(logging): centralize websocket handshake recording 2026-04-02 18:39:32 +08:00
hkfires
34339f61ee Refactor websocket logging and error handling
- Introduced new logging functions for websocket requests, handshakes, errors, and responses in `logging_helpers.go`.
- Updated `CodexWebsocketsExecutor` to utilize the new logging functions for improved clarity and consistency in websocket operations.
- Modified the handling of websocket upgrade rejections to log relevant metadata.
- Changed the request body key to a timeline body key in `openai_responses_websocket.go` to better reflect its purpose.
- Enhanced tests to verify the correct logging of websocket events and responses, including disconnect events and error handling scenarios.
2026-04-02 17:30:51 +08:00
pzy
4045378cb4 fix: strengthen Claude reverse-proxy anti-detection capabilities
Based on source analysis of Claude Code v2.1.88, close several gaps that Anthropic could detect:

- Implement the message fingerprint algorithm (SHA256 salt + character index), replacing the random buildHash
- The billing header cc_version now takes its version dynamically from the device profile instead of being hardcoded
- The billing header cc_entrypoint is parsed from the client UA, supporting cli/vscode/local-agent
- Add billing header cc_workload support (passed via the X-CPA-Claude-Workload header)
- Add the X-Claude-Code-Session-Id header (a UUID cached per apiKey, TTL=1h)
- Add the x-client-request-id header (api.anthropic.com only, a fresh UUID per request)
- Fill in 4 missing beta flags (structured-outputs/fast-mode/redact-thinking/token-efficient-tools)
- Align OAuth scopes with Claude Code 2.1.88 (remove org:create_api_key; add sessions/mcp/file_upload)
- Send Anthropic-Dangerous-Direct-Browser-Access only in API key mode
- Scrub gateway fingerprints from response headers (strip litellm/helicone/portkey/cloudflare/kong/braintrust-prefixed headers)
2026-04-02 15:55:22 +08:00
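A hypothetical sketch of the salted, character-indexed fingerprint mentioned above; the real salt, sampling positions, and output encoding are not public, so everything below the function comment is assumption:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// computeFingerprint hashes a salt together with characters sampled by
// RUNE index, so multi-byte characters hash identically regardless of
// their byte offsets. The every-7th sampling stride is illustrative only.
func computeFingerprint(salt, message string) string {
	runes := []rune(message) // rune indexing, not byte indexing
	h := sha256.New()
	h.Write([]byte(salt))
	for i := 0; i < len(runes); i += 7 {
		h.Write([]byte(string(runes[i])))
	}
	return hex.EncodeToString(h.Sum(nil))
}

func main() {
	fp := computeFingerprint("salt", "你好, Claude")
	fmt.Println(len(fp)) // 64 hex chars for SHA-256
}
```

The point is the determinism: unlike the random buildHash it replaces, the same salt and message always yield the same fingerprint.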
Luis Pater
2df35449fe Fix executor compat helpers 2026-04-02 12:20:12 +08:00
Luis Pater
c744179645 Merge PR #479 2026-04-02 12:15:33 +08:00
Luis Pater
9720b03a6b Merge pull request #477 from ben-vargas/plus-main
fix(copilot): route Gemini preview models to chat endpoint and correct context lengths
2026-04-02 11:36:51 +08:00
Luis Pater
f2c0f3d325 Merge pull request #476 from hungthai1401/fix/ghc-gpt54mini
Fix GitHub Copilot gpt-5.4-mini endpoint routing
2026-04-02 11:36:26 +08:00
Luis Pater
4f99bc54f1 test: update codex header expectations 2026-04-02 11:19:37 +08:00
Luis Pater
913f4a9c5f test: fix executor tests after helpers refactor 2026-04-02 11:12:30 +08:00
Luis Pater
25d1c18a3f fix: scope experimental cch signing to billing header 2026-04-02 11:03:11 +08:00
Luis Pater
d09dd4d0b2 Merge commit '15c2f274ea690c9a7c9db22f9f454af869db5375' into dev 2026-04-02 10:59:54 +08:00
Luis Pater
474fb042da Merge pull request #2476 from router-for-me/cherry-pick/pr-2438-to-dev
Cherry-pick PR #2438 onto dev
2026-04-02 10:36:50 +08:00
Michael
8435c3d7be feat(tui): show time in usage details 2026-04-02 10:35:13 +08:00
Luis Pater
e783d0a62e Merge pull request #2441 from MonsterQiu/issue-2421-alias-before-suspension
fix(auth): resolve oauth aliases before suspension checks
2026-04-02 10:27:39 +08:00
Luis Pater
b05f575e9b Merge pull request #2444 from 0oAstro/fix/codex-nonstream-finish-reason-tool-calls
fix(codex): set finish_reason to "tool_calls" in non-streaming response when tool calls are present
2026-04-02 10:01:25 +08:00
edlsh
15c2f274ea fix: preserve cloak config defaults when mode omitted 2026-04-01 13:20:11 -04:00
edlsh
37249339ac feat: add opt-in experimental Claude cch signing 2026-04-01 13:03:17 -04:00
Luis Pater
c422d16beb Merge pull request #2398 from 7RPH/fix/responses-sse-framing
fix: preserve SSE event boundaries for Responses streams
2026-04-02 00:46:51 +08:00
Luis Pater
66cd50f603 Merge pull request #2468 from router-for-me/ip
fix(openai): improve client IP retrieval in websocket handler
2026-04-02 00:03:35 +08:00
hkfires
caa529c282 fix(openai): improve client IP retrieval in websocket handler 2026-04-01 20:16:01 +08:00
hkfires
51a4379bf4 refactor(openai): remove websocket body log truncation limit 2026-04-01 18:11:43 +08:00
Luis Pater
acf98ed10e fix(openai): add session reference counter and cache lifecycle management for websocket tools 2026-04-01 17:28:50 +08:00
Luis Pater
d1c07a091e fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency 2026-04-01 17:16:49 +08:00
Ben Vargas
c1a8adf1ab feat(registry): add GitHub Copilot gemini-3.1-pro-preview model 2026-04-01 01:25:03 -06:00
Ben Vargas
08e078fc25 fix(openai): route copilot Gemini preview models to chat endpoint 2026-04-01 01:24:58 -06:00
Luis Pater
105a21548f fix(codex): centralize session management with global store and add tests for executor session lifecycle 2026-04-01 13:17:10 +08:00
Luis Pater
1734aa1664 fix(codex): prioritize websocket-enabled credentials across priority tiers in scheduler logic 2026-04-01 12:51:12 +08:00
Luis Pater
ca11b236a7 refactor(runtime, openai): simplify header management and remove redundant websocket logging logic 2026-04-01 11:57:31 +08:00
huynhgiabuu
6fdff8227d docs: add ProxyPal to 'Who is with us?' section
Add ProxyPal (https://github.com/buddingnewinsights/proxypal) to the
community projects list in all three README files (EN, CN, JA).
Placed after CCS, restoring its original position.

ProxyPal is a cross-platform desktop app (macOS, Windows, Linux) that
wraps CLIProxyAPI with a native GUI, supporting multiple AI providers,
usage analytics, request monitoring, and auto-configuration for popular
coding tools.

Closes #2420
2026-04-01 10:23:22 +07:00
Luis Pater
330e12d3c2 fix(codex): conditionally set Session_id header for Mac OS user agents and clean up redundant logic 2026-04-01 11:11:45 +08:00
Thai Nguyen Hung
bd09c0bf09 feat(registry): add gpt-5.4-mini model to GitHub Copilot registry 2026-04-01 10:04:38 +07:00
Luis Pater
b468ca79c3 Merge branch 'dev' of github.com:router-for-me/CLIProxyAPI into dev 2026-04-01 03:09:03 +08:00
Luis Pater
d2c7e4e96a refactor(runtime): move executor utilities to helps package and update references 2026-04-01 03:08:20 +08:00
Luis Pater
1c7003ff68 Merge pull request #2452 from Lucaszmv/fix-qwen-cli-v0.13.2
fix(qwen): update CLI simulation to v0.13.2 and adjust header casing
2026-04-01 02:44:27 +08:00
Lucaszmv
1b44364e78 fix(qwen): update CLI simulation to v0.13.2 2026-03-31 15:19:07 -03:00
0oAstro
ec77f4a4f5 fix(codex): set finish_reason to tool_calls in non-streaming response when tool calls are present 2026-03-31 14:12:15 +05:30
MonsterQiu
f611dd6e96 refactor(auth): dedupe route-aware model support checks 2026-03-31 15:42:25 +08:00
MonsterQiu
07b7c1a1e0 fix(auth): resolve oauth aliases before suspension checks 2026-03-31 14:27:14 +08:00
Luis Pater
51fd58d74f fix(codex): use normalizeCodexInstructions to set default instructions 2026-03-31 12:16:57 +08:00
Luis Pater
faae9c2f7c Merge pull request #2422 from MonsterQiu/fix/codex-compact-instructions
fix(codex): add default instructions for /responses/compact
2026-03-31 12:14:20 +08:00
Luis Pater
bc3a6e4646 Merge pull request #2434 from MonsterQiu/fix/codex-responses-null-instructions
fix(codex): normalize null instructions for /responses requests
2026-03-31 12:01:21 +08:00
Luis Pater
b09b03e35e Merge pull request #2424 from possible055/fix/websocket-transcript-replacement
fix(openai): handle transcript replacement after websocket v2 compaction
2026-03-31 11:00:33 +08:00
Luis Pater
16231947e7 Merge pull request #2426 from xixiwenxuanhe/feature/antigravity-credits
feat(antigravity): add AI credits quota fallback
2026-03-31 10:51:40 +08:00
MonsterQiu
39b9a38fbc fix(codex): normalize null instructions across responses paths 2026-03-31 10:32:39 +08:00
MonsterQiu
bd855abec9 fix(codex): normalize null instructions for responses requests 2026-03-31 10:29:02 +08:00
Luis Pater
7c3c2e9f64 Merge pull request #2417 from CharTyr/fix/amp-streaming-thinking-regression
fix(amp): fix blank TUI replies caused by thinking blocks being wrongly suppressed in streaming responses
2026-03-31 10:12:13 +08:00
Luis Pater
c10f8ae2e2 Fixed: #2420
docs(readme): remove ProxyPal section from all README translations
2026-03-31 07:23:02 +08:00
xixiwenxuanhe
a0bf33eca6 fix(antigravity): preserve fallback and honor config gate 2026-03-31 00:14:05 +08:00
xixiwenxuanhe
88dd9c715d feat(antigravity): add AI credits quota fallback 2026-03-30 23:58:12 +08:00
apparition
a3e21df814 fix(openai): avoid developer transcript resets
- Narrow websocket transcript replacement detection to assistant outputs and function calls
- Preserve existing merge behavior for follow-up developer messages without previous_response_id
- Add a regression test covering mid-session developer message updates
2026-03-30 23:33:16 +08:00
MonsterQiu
d3b94c9241 fix(codex): normalize null instructions for compact requests 2026-03-30 22:58:05 +08:00
apparition
c1d7599829 fix(openai): handle transcript replacement after websocket compaction
- Add shouldReplaceWebsocketTranscript() to detect historical model output in input
- Add normalizeResponseTranscriptReplacement() for full transcript reset handling
- Prevent duplicate stale turn-state when clients replace local history post-compaction
- Avoid orphaned function_call items from incremental append on compact transcripts
- Add unit tests for transcript replacement detection and state reset behavior
2026-03-30 22:44:58 +08:00
MonsterQiu
d11936f292 fix(codex): add default instructions for /responses/compact 2026-03-30 22:44:46 +08:00
Luis Pater
17363edf25 fix(auth): skip downtime for request-scoped 404 errors in model state management 2026-03-30 22:22:42 +08:00
CharTyr
279cbbbb8a fix(amp): don't suppress thinking blocks in streaming mode
Reverts the streaming thinking suppression introduced in b15453c.
rewriteStreamEvent should only inject signatures and rewrite model
names — suppressing thinking blocks in streaming mode breaks SSE
index alignment and causes the Amp TUI to render empty responses
on the second message onward (especially with model-mapped
non-Claude providers like GPT-5.4).

Non-streaming responses still suppress thinking when tool_use is
present via rewriteModelInResponse.
2026-03-30 20:09:32 +08:00
Luis Pater
486cd4c343 Merge pull request #2409 from sususu98/fix/tool-use-pairing-break
fix(antigravity): reorder model parts to prevent tool_use↔tool_result pairing breakage
2026-03-30 16:59:46 +08:00
sususu98
25feceb783 fix(antigravity): reorder model parts to prevent tool_use↔tool_result pairing breakage
When a Claude assistant message contains [text, tool_use, text], the
Antigravity API internally splits the model message at functionCall
boundaries, creating an extra assistant turn between tool_use and the
following tool_result. Claude then rejects with:

  tool_use ids were found without tool_result blocks immediately after

Fix: extend the existing 2-way part reordering (thinking-first) to a
3-way partition: thinking → regular → functionCall. This ensures
functionCall parts are always last, so Antigravity's split cannot
insert an extra assistant turn before the user's tool_result.

Fixes #989
2026-03-30 15:09:33 +08:00
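The 3-way partition described in this fix can be sketched as a stable reorder; the `Part` type here is illustrative, not the project's real struct:

```go
package main

import "fmt"

// Part is an illustrative stand-in for a model message part.
type Part struct {
	Kind string // "thinking", "text", "functionCall"
}

// partitionParts reorders parts as thinking -> regular -> functionCall,
// preserving relative order within each group, so functionCall parts are
// always last and a provider-side split cannot insert an assistant turn
// between tool_use and the following tool_result.
func partitionParts(parts []Part) []Part {
	var thinking, regular, calls []Part
	for _, p := range parts {
		switch p.Kind {
		case "thinking":
			thinking = append(thinking, p)
		case "functionCall":
			calls = append(calls, p)
		default:
			regular = append(regular, p)
		}
	}
	out := append(thinking, regular...)
	return append(out, calls...) // functionCall always last
}

func main() {
	in := []Part{{"text"}, {"functionCall"}, {"text"}, {"thinking"}}
	fmt.Println(partitionParts(in)) // [{thinking} {text} {text} {functionCall}]
}
```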
Luis Pater
d26752250d Merge pull request #2403 from CharTyr/clean-pr
fix(amp): fix TUI crashes and upstream 400s caused by missing/invalid signatures in the Amp CLI integration
2026-03-30 12:54:15 +08:00
CharTyr
b15453c369 fix(amp): address PR review - stream thinking suppression, SSE detection, test init
- Call suppressAmpThinking in rewriteStreamEvent for streaming path
- Handle nil return from suppressAmpThinking to skip suppressed events
- Narrow looksLikeSSEChunk to line-prefix detection (HasPrefix vs Contains)
- Initialize suppressedContentBlock map in test
2026-03-30 00:42:04 -04:00
CharTyr
04ba8c8bc3 feat(amp): sanitize signatures and handle stream suppression for Amp compatibility 2026-03-29 22:23:18 -04:00
Luis Pater
6570692291 Merge pull request #2400 from router-for-me/revert-2374-codex-cache-clean
Revert "fix(codex): restore prompt cache continuity for Codex requests"
2026-03-29 22:19:39 +08:00
trph
f73d55ddaa fix: simplify responses SSE suffix handling 2026-03-29 22:19:25 +08:00
Luis Pater
13aa5b3375 Revert "fix(codex): restore prompt cache continuity for Codex requests" 2026-03-29 22:18:14 +08:00
trph
0fcc02fbea fix: tighten responses SSE review follow-up 2026-03-29 22:10:28 +08:00
trph
c03883ccf0 fix: address responses SSE review feedback 2026-03-29 22:00:46 +08:00
trph
134a9eac9d fix: preserve SSE event boundaries for Responses streams 2026-03-29 17:23:16 +08:00
Luis Pater
6d8de0ade4 feat(auth): implement weighted provider rotation for improved scheduling fairness 2026-03-29 13:49:01 +08:00
Luis Pater
1587ff5e74 Merge pull request #2389 from router-for-me/claude
fix(claude): add default max_tokens for models
2026-03-29 13:03:20 +08:00
hkfires
f033d3a6df fix(claude): enhance ensureModelMaxTokens to use registered max_completion_tokens and fallback to default 2026-03-29 13:00:43 +08:00
hkfires
145e0e0b5d fix(claude): add default max_tokens for models 2026-03-29 12:46:00 +08:00
Luis Pater
f8d1bc06ea Merge pull request #469 from router-for-me/plus
v6.9.5
2026-03-29 12:40:26 +08:00
Luis Pater
d5930f4e44 Merge branch 'main' into plus 2026-03-29 12:40:17 +08:00
Luis Pater
9b7d7021af docs(readme): update LingtrueAPI link in all README translations 2026-03-29 12:30:24 +08:00
Luis Pater
e41c22ef44 docs(readme): add LingtrueAPI sponsorship details to all README translations 2026-03-29 12:23:37 +08:00
Luis Pater
55271403fb Merge pull request #2374 from VooDisss/codex-cache-clean
fix(codex): restore prompt cache continuity for Codex requests
2026-03-28 21:16:51 +08:00
Luis Pater
36fba66619 Merge pull request #2371 from RaviTharuma/docs/provider-specific-routes
docs: clarify provider-specific routing for aliased models
2026-03-28 21:11:29 +08:00
Luis Pater
b9b127a7ea Merge pull request #2347 from edlsh/fix/codex-strip-stream-options
fix(codex): strip stream_options from Responses API requests
2026-03-28 21:03:01 +08:00
Luis Pater
2741e7b7b3 Merge pull request #2346 from pjpjq/codex/fix-codex-capacity-retry
fix(codex): Treat Codex capacity errors as retryable
2026-03-28 21:00:50 +08:00
Luis Pater
1767a56d4f Merge pull request #2343 from kongkk233/fix/proxy-transport-defaults
Preserve default transport settings for proxy clients
2026-03-28 20:58:24 +08:00
Luis Pater
779e6c2d2f Merge pull request #2231 from 7RPH/fix/responses-stream-multi-tool-calls
fix: preserve separate streamed tool calls in Responses API
2026-03-28 20:53:19 +08:00
Luis Pater
73c831747b Merge pull request #2133 from DragonFSKY/fix/2061-stale-modelstates
fix(auth): prevent stale runtime state inheritance from disabled auth entries
2026-03-28 20:50:57 +08:00
Luis Pater
b8b89f34f4 Merge pull request #442 from LuxVTZ/feat/gitlab-duo-panel-parity
Improve GitLab Duo gateway compatibility
Restore internal/runtime/executor/claude_executor.go to main during merge.
2026-03-28 05:06:41 +08:00
Luis Pater
1fa094dac6 Merge pull request #461 from MrHuangJser/main
feat(cursor): Full Cursor provider with H2 streaming, MCP tools, multi-turn & multi-account
2026-03-28 05:01:27 +08:00
Luis Pater
f55754621f Merge pull request #464 from router-for-me/plus
v6.9.4
2026-03-28 04:51:27 +08:00
Luis Pater
ac26e7db43 Merge branch 'main' into plus 2026-03-28 04:51:18 +08:00
Luis Pater
10b824fcac fix(security): validate auth file names to prevent unsafe input 2026-03-28 04:48:23 +08:00
VooDisss
e5d3541b5a refactor(codex): remove stale affinity cleanup leftovers
Drop the last affinity-related executor artifacts so the PR stays focused on the minimal Codex continuity fix set: stable prompt cache identity, stable session_id, and the executor-only behavior that was validated to restore cache reads.
2026-03-27 20:40:26 +02:00
VooDisss
79755e76ea refactor(pr): remove forbidden translator changes
Drop the chat-completions translator edits from this PR so the branch complies with the repository policy that forbids pull-request changes under internal/translator. The remaining PR stays focused on the executor-level Codex continuity fix that was validated to restore cache reuse.
2026-03-27 19:34:13 +02:00
VooDisss
35f158d526 refactor(pr): narrow Codex cache fix scope
Remove the experimental auth-affinity routing changes from this PR so it stays focused on the validated Codex continuity fix. This keeps the prompt-cache repair while avoiding unrelated routing-policy concerns such as provider/model affinity scope, lifecycle cleanup, and hard-pin fallback semantics.
2026-03-27 19:06:34 +02:00
VooDisss
6962e09dd9 fix(auth): scope affinity by provider
Keep sticky auth affinity limited to matching providers and stop persisting execution-session IDs as long-lived affinity keys so provider switching and normal streaming traffic do not create incorrect pins or stale affinity state.
2026-03-27 18:52:58 +02:00
VooDisss
4c4cbd44da fix(auth): avoid leaking or over-persisting affinity keys
Stop using one-shot idempotency keys as long-lived auth-affinity identifiers and remove raw affinity-key values from debug logs so sticky routing keeps its continuity benefits without creating avoidable memory growth or credential exposure risks.
2026-03-27 18:34:51 +02:00
VooDisss
26eca8b6ba fix(codex): preserve continuity and safe affinity fallback
Restore Claude continuity after the continuity refactor, keep auth-affinity keys out of upstream Codex session identifiers, and only persist affinity after successful execution so retries can still rotate to healthy credentials when the first auth fails.
2026-03-27 18:27:33 +02:00
VooDisss
62b17f40a1 refactor(codex): align continuity helpers with review feedback
Align websocket continuity resolution with the HTTP Codex path, make auth-affinity principal keys use a stable string representation, and extract small helpers that remove duplicated continuity and affinity logic without changing the validated cache-hit behavior.
2026-03-27 18:11:57 +02:00
VooDisss
511b8a992e fix(codex): restore prompt cache continuity for Codex requests
Prompt caching on Codex was not reliably reusable through the proxy because repeated chat-completions requests could reach the upstream without the same continuity envelope. In practice this showed up most clearly with OpenCode, where cache reads worked in the reference client but not through CLIProxyAPI, although the root cause is broader than OpenCode itself.

The proxy was breaking continuity in several ways: executor-layer Codex request preparation stripped prompt_cache_retention, chat-completions translation did not preserve that field, continuity headers used a different shape than the working client behavior, and OpenAI-style Codex requests could be sent without a stable prompt_cache_key. When that happened, session_id fell back to a fresh random value per request, so upstream Codex treated repeated requests as unrelated turns instead of as part of the same cacheable context.

This change fixes that by preserving caller-provided prompt_cache_retention on Codex execution paths, preserving prompt_cache_retention when translating OpenAI chat-completions requests to Codex, aligning Codex continuity headers to session_id, and introducing an explicit Codex continuity policy that derives a stable continuity key from the best available signal. The resolution order prefers an explicit prompt_cache_key, then execution session metadata, then an explicit idempotency key, then stable request-affinity metadata, then a stable client-principal hash, and finally a stable auth-ID hash when no better continuity signal exists.

The same continuity key is applied to both prompt_cache_key in the request body and session_id in the request headers so repeated requests reuse the same upstream cache/session identity. The auth manager also keeps auth selection sticky for repeated request sequences, preventing otherwise-equivalent Codex requests from drifting across different upstream auth contexts and accidentally breaking cache reuse.

To keep the implementation maintainable, the continuity resolution and diagnostics are centralized in a dedicated Codex continuity helper instead of being scattered across executor flow code. Regression coverage now verifies retention preservation, continuity-key precedence, stable auth-ID fallback, websocket parity, translator preservation, and auth-affinity behavior. Manual validation confirmed prompt cache reads now occur through CLIProxyAPI when using Codex via OpenCode, and the fix should also benefit other clients that rely on stable repeated Codex request continuity.
2026-03-27 17:49:29 +02:00
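The continuity-key precedence this commit spells out is a first-non-empty fallback chain; a minimal sketch, with illustrative parameter names standing in for the real metadata plumbing:

```go
package main

import "fmt"

// resolveContinuityKey returns the first available continuity signal in
// the order described above. The winning key is applied to both the
// prompt_cache_key body field and the session_id header.
func resolveContinuityKey(promptCacheKey, sessionMeta, idempotencyKey, affinityMeta, principalHash, authIDHash string) string {
	for _, candidate := range []string{
		promptCacheKey, // explicit prompt_cache_key wins
		sessionMeta,    // execution session metadata
		idempotencyKey, // explicit idempotency key
		affinityMeta,   // stable request-affinity metadata
		principalHash,  // stable client-principal hash
	} {
		if candidate != "" {
			return candidate
		}
	}
	return authIDHash // last resort: stable auth-ID hash
}

func main() {
	fmt.Println(resolveContinuityKey("", "sess-42", "", "", "", "auth-hash")) // sess-42
}
```

Because every branch is stable across repeated requests, the upstream never sees a fresh random session_id for what is really the same cacheable context.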
Luis Pater
7dccc7ba2f docs(readme): remove redundant whitespace in BmoPlus sponsorship section of Chinese README 2026-03-27 20:52:14 +08:00
Luis Pater
70c90687fd docs(readme): fix formatting in BmoPlus sponsorship section of Chinese README 2026-03-27 20:49:43 +08:00
Luis Pater
8144ffd5c8 Merge pull request #2370 from B3o/add-bmoplus-sponsor
docs: add BmoPlus sponsorship banners to READMEs
2026-03-27 20:48:22 +08:00
Ravi Tharuma
0ab977c236 docs: clarify provider path limitations 2026-03-27 11:13:08 +01:00
Ravi Tharuma
224f0de353 docs: neutralize provider-specific path wording 2026-03-27 11:11:06 +01:00
B3o
6b45d311ec add BmoPlus sponsorship banners to READMEs 2026-03-27 18:01:35 +08:00
Ravi Tharuma
d54de441d3 docs: clarify provider-specific routing for aliased models 2026-03-27 10:53:09 +01:00
MrHuangJser
7386a70724 feat(cursor): auto-identify accounts from JWT sub for multi-account support
Previously Cursor required a manual ?label=xxx parameter to distinguish
accounts (unlike Codex which auto-generates filenames from JWT claims).

Cursor JWTs contain a "sub" claim (e.g. "auth0|user_XXXX") that uniquely
identifies each account. Now we:

- Add ParseJWTSub() + SubToShortHash() to extract and hash the sub claim
- Refactor GetTokenExpiry() to share the new decodeJWTPayload() helper
- Update CredentialFileName(label, subHash) to auto-generate filenames
  from the sub hash when no explicit label is provided
  (e.g. "cursor.8f202e67.json" instead of always "cursor.json")
- Add DisplayLabel() for human-readable account identification
- Store "sub" in metadata for observability
- Update both management API handler and SDK authenticator

Same account always produces the same filename (deterministic), different
accounts get different files. Explicit ?label= still takes priority.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 17:40:02 +08:00
白金
1821bf7051 docs: add BmoPlus sponsorship banners to READMEs 2026-03-27 17:39:29 +08:00
Luis Pater
d42b5d4e78 docs(readme): update QQ group information in Chinese README 2026-03-27 11:46:21 +08:00
MrHuangJser
1b7447b682 feat(cursor): implement StatusError for conductor cooldown integration
Cursor executor errors were plain fmt.Errorf — the conductor couldn't
extract HTTP status codes, so exhausted accounts never entered cooldown.

Changes:
- Add ConnectError struct to proto/connect.go: ParseConnectEndStream now
  returns *ConnectError with Code/Message fields for precise matching
- Add cursorStatusErr implementing StatusError + RetryAfter interfaces
- Add classifyCursorError() with two-layer classification:
  Layer 1: exact match on ConnectError.Code (gRPC standard codes)
    resource_exhausted → 429, unauthenticated → 401,
    permission_denied → 403, unavailable → 503, internal → 500
  Layer 2: fuzzy string match for H2 errors (RST_STREAM → 502)
- Log all ConnectError code/message pairs for observing real server
  error codes (we have no samples yet)
- Wrap Execute and ExecuteStream error returns with classifyCursorError

Now the conductor properly marks Cursor auths as cooldown on quota errors,
enabling exponential backoff and round-robin failover.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 11:42:22 +08:00
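Layer 1 of the classification above is a direct code-to-status table; a sketch of that mapping, with the 502 default standing in for the fuzzy H2 layer:

```go
package main

import "fmt"

// statusFromConnectCode maps gRPC-standard Connect error codes to HTTP
// statuses so the conductor can apply cooldown/retry policy. Unknown
// codes fall through to 502, the bucket used for H2-level failures
// such as RST_STREAM in the commit above.
func statusFromConnectCode(code string) int {
	switch code {
	case "resource_exhausted":
		return 429
	case "unauthenticated":
		return 401
	case "permission_denied":
		return 403
	case "unavailable":
		return 503
	case "internal":
		return 500
	default:
		return 502
	}
}

func main() {
	fmt.Println(statusFromConnectCode("resource_exhausted")) // 429 -> cooldown
}
```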
MrHuangJser
40dee4453a feat(cursor): auto-migrate sessions to healthy account on quota exhaustion
When a Cursor account's quota is exhausted, sessions bound to it can now
seamlessly continue on a different account:

Layer 1 — Checkpoint decoupling:
  Key checkpoints by conversationId (not authID:conversationId). Store
  authID inside savedCheckpoint. On lookup, if auth changed, discard the
  stale checkpoint and flatten conversation history into userText.

Layer 2 — Cross-account session cleanup:
  When a request arrives for a conversation whose session belongs to a
  different (now-exhausted) auth, close the old H2 stream and remove
  the stale session to free resources.

Layer 3 — H2Stream.Err() exposure:
  New Err() method on H2Stream so callers can inspect RST_STREAM,
  GOAWAY, or other stream-level errors after closure.

Layer 4 — processH2SessionFrames error propagation:
  Returns error instead of bare return. Connect EndStream errors (quota,
  rate limit) are now propagated instead of being logged and swallowed.

Layer 5 — Pre-response transparent retry:
  If the stream fails before any data is sent to the client, return an
  error to the conductor so it retries with a different auth — fully
  transparent to the client.

Layer 6 — Post-response error logging:
  If the stream fails after data was already sent, log a warning. The
  conductor's existing cooldown mechanism ensures the next request routes
  to a healthy account.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 10:50:32 +08:00
MrHuangJser
8902e1cccb style(cursor): replace fmt.Print* with log package for consistent logging
Address Gemini Code Assist review feedback: use logrus log package
instead of fmt.Printf/Println in Cursor auth handlers and CLI for
unified log formatting and level control.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:03:32 +08:00
黄姜恒
de5fe71478 feat(cursor): multi-account routing with round-robin and session isolation
- Add cursor/filename.go for multi-account credential file naming
- Include auth.ID in session and checkpoint keys for per-account isolation
- Record authID in cursorSession, validate on resume to prevent cross-account access
- Management API /cursor-auth-url supports ?label= for creating named accounts
- Leverages existing conductor round-robin + failover framework

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:27:49 +08:00
黄姜恒
dcfbec2990 feat(cursor): add management API for Cursor OAuth authentication
- Add RequestCursorToken handler with PKCE + polling flow
- Register /v0/management/cursor-auth-url route
- Returns login URL + state for browser auth, polls in background
- Saves cursor.json with access/refresh tokens on success

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:10:07 +08:00
黄姜恒
c95620f90e feat(cursor): conversation checkpoint + session_id for multi-turn context
- Capture conversation_checkpoint_update from Cursor server (was ignored)
- Store checkpoint per conversationId, replay as conversation_state on next request
- Use protowire to embed raw checkpoint bytes directly (no deserialization)
- Extract session_id from Claude Code metadata for stable conversationId across resume
- Flatten conversation history into userText as fallback when no checkpoint available
- Use conversationId as session key for reliable tool call resume
- Add checkpoint TTL cleanup (30min)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:51:47 +08:00
edlsh
754f3bcbc3 fix(codex): strip stream_options from Responses API requests
The Codex/OpenAI Responses API does not support the stream_options
parameter. When clients (e.g. Amp CLI) include stream_options in their
requests, CLIProxyAPI forwards it as-is, causing a 400 error:

  {"detail":"Unsupported parameter: stream_options"}

Strip stream_options alongside the other unsupported parameters
(previous_response_id, prompt_cache_retention, safety_identifier)
in Execute, ExecuteStream, and CountTokens.
2026-03-25 11:58:36 -04:00
pjpj
36973d4a6f Handle Codex capacity errors as retryable 2026-03-25 23:25:31 +08:00
黄姜恒
9613f0b3f9 feat(cursor): deterministic conversation_id from Claude Code session cch
Extract the cch hash from Claude Code's billing header in the system
prompt (x-anthropic-billing-header: ...cch=XXXXX;) and use it to derive
a deterministic conversation_id instead of generating a random UUID.

Same Claude Code session → same cch → same conversation_id → Cursor
server can reuse conversation state across multiple turns, preserving
tool call results and other context without re-encoding history.

Also cleans up temporary debug logging from previous iterations.
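The derivation above can be sketched as follows. The regex, helper names, and UUID-style formatting here are assumptions based on this commit message, not the repository's actual code:

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"regexp"
)

// cchPattern matches the cch value inside the billing header line that
// Claude Code embeds in its system prompt, e.g.
// "x-anthropic-billing-header: ...cch=XXXXX;".
var cchPattern = regexp.MustCompile(`cch=([A-Za-z0-9_-]+)`)

// conversationIDFromPrompt derives a deterministic conversation_id from
// the cch hash, so the same Claude Code session always maps to the same
// Cursor conversation across turns.
func conversationIDFromPrompt(systemPrompt string) (string, bool) {
	m := cchPattern.FindStringSubmatch(systemPrompt)
	if m == nil {
		return "", false // caller falls back to a random UUID
	}
	sum := sha256.Sum256([]byte(m[1]))
	// Format the first 16 digest bytes as a UUID-shaped string.
	return fmt.Sprintf("%x-%x-%x-%x-%x", sum[0:4], sum[4:6], sum[6:8], sum[8:10], sum[10:16]), true
}

func main() {
	prompt := "x-anthropic-billing-header: plan=pro;cch=abc123;"
	id, ok := conversationIDFromPrompt(prompt)
	fmt.Println(id, ok)
}
```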

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:29:49 +08:00
黄姜恒
274f29e26b fix(cursor): improve session key uniqueness for multi-session safety
Include system prompt prefix (first 200 chars) in session key derivation.
Claude Code sessions have unique system prompts containing cwd, session_id,
file paths, etc., making collisions between concurrent sessions from the
same user virtually impossible.

Session key now = SHA256(apiKey + model + systemPrompt[:200] + firstUserMsg)
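A minimal sketch of that derivation (helper and parameter names are illustrative):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// sessionKey mirrors the formula above:
// SHA256(apiKey + model + systemPrompt[:200] + firstUserMsg).
// Including the API key and system-prompt prefix makes collisions
// between concurrent sessions vanishingly unlikely.
func sessionKey(apiKey, model, systemPrompt, firstUserMsg string) string {
	prefix := systemPrompt
	if len(prefix) > 200 {
		prefix = prefix[:200]
	}
	sum := sha256.Sum256([]byte(apiKey + model + prefix + firstUserMsg))
	return hex.EncodeToString(sum[:])
}

func main() {
	k1 := sessionKey("key-a", "claude-sonnet-4", "cwd=/home/a session_id=1", "hello")
	k2 := sessionKey("key-b", "claude-sonnet-4", "cwd=/home/a session_id=1", "hello")
	// Different API keys yield different session keys.
	fmt.Println(k1 != k2)
}
```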

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:24:37 +08:00
黄姜恒
c8e79c3787 fix(cursor): prevent session key collision across users
Include client API key in session key derivation to prevent different
users sharing the same proxy from accidentally resuming each other's
H2 streams when they send identical first messages with the same model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:19:11 +08:00
黄姜恒
8afef43887 fix(cursor): preserve tool call context in multi-turn conversations
When an assistant message appears after tool results without a pending
user message, append it to the last turn's assistant text instead of
dropping it. Also add bakeToolResultsIntoTurns() to merge tool results
into turn context when no active H2 session exists for resume, ensuring
the model sees the full tool interaction history in follow-up requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:15:24 +08:00
黄姜恒
c1083cbfc6 fix(cursor): MCP tool call resume, H2 flow control, and token usage
- Rewrite tool call mechanism from interrupt-resume to inline-wait mode:
  processH2SessionFrames no longer exits on mcpArgs; instead blocks on
  toolResultCh while continuing to handle KV/heartbeat messages, then
  sends MCP result and continues processing text in the same goroutine.
  Fixes the issue where server stopped generating text after resume.

- Add switchable output channel (outMu/currentOut) so first HTTP response
  closes after tool_calls+[DONE], and resumed text goes to a new channel
  returned by resumeWithToolResults. Reset streamParam on switch so
  Translator produces fresh message_start/content_block_start events.

- Implement send-side H2 flow control: track server's initial window size
  and WINDOW_UPDATE increments; Write() blocks when window exhausted.
  Fixes RST_STREAM FLOW_CONTROL_ERROR on large requests (178KB+).

- Decode new InteractionUpdate fields: TurnEndedUpdate (field 14) as
  stream termination signal, HeartbeatUpdate (field 13) silently ignored,
  TokenDeltaUpdate (field 8) for token usage tracking.

- Include token usage in final stop chunk (prompt_tokens estimated from
  payload size, completion_tokens from accumulated TokenDeltaUpdate deltas)
  so Claude CLI status bar shows non-zero token counts.
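The send-side flow-control bullet can be sketched as a condition-variable window tracker. This is a simplification under assumed names; the real client also has to juggle per-connection versus per-stream windows:

```go
package main

import (
	"fmt"
	"sync"
)

// sendWindow tracks the server-granted flow-control window: writes
// consume credit and block while the window is exhausted, and
// WINDOW_UPDATE frames replenish it.
type sendWindow struct {
	mu     sync.Mutex
	cond   *sync.Cond
	window int32
}

func newSendWindow(initial int32) *sendWindow {
	w := &sendWindow{window: initial}
	w.cond = sync.NewCond(&w.mu)
	return w
}

// consume blocks until n bytes of window credit are available, then takes them.
func (w *sendWindow) consume(n int32) {
	w.mu.Lock()
	defer w.mu.Unlock()
	for w.window < n {
		w.cond.Wait()
	}
	w.window -= n
}

// release applies a WINDOW_UPDATE increment from the server and wakes writers.
func (w *sendWindow) release(n int32) {
	w.mu.Lock()
	w.window += n
	w.mu.Unlock()
	w.cond.Broadcast()
}

func main() {
	w := newSendWindow(65535) // HTTP/2 default initial window size
	w.consume(1024)
	fmt.Println("remaining window:", 65535-1024)
}
```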

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:03:14 +08:00
kwz
c89d19b300 Preserve default transport settings for proxy clients 2026-03-25 15:33:09 +08:00
Luis Pater
1e6bc81cfd refactor(config): replace auto-update-panel with disable-auto-update-panel for clarity 2026-03-25 10:31:44 +08:00
Luis Pater
1a149475e0 Merge pull request #2293 from Xvvln/fix/management-asset-security
fix(security): harden management panel asset updater
2026-03-25 10:22:49 +08:00
Luis Pater
e5166841db Merge pull request #2310 from shellus/fix/claude-openai-system-top-level
fix: preserve OpenAI system messages as Claude top-level system
2026-03-25 10:21:18 +08:00
黄姜恒
19c52bcb60 feat: stash code 2026-03-25 10:14:14 +08:00
Luis Pater
bb9b2d1758 Merge pull request #2320 from cikichen/build/freebsd-support
build: add freebsd support for releases
2026-03-25 10:12:35 +08:00
simon
d312422ab4 build: add freebsd support to releases 2026-03-24 16:49:04 +08:00
GeJiaXiang
09c92aa0b5 fix: keep a fallback turn for system-only Claude inputs 2026-03-24 13:54:25 +08:00
GeJiaXiang
8c67b3ae64 test: verify remaining user message after system merge 2026-03-24 13:47:52 +08:00
GeJiaXiang
000e4ceb4e fix: map OpenAI system messages to Claude top-level system 2026-03-24 13:42:33 +08:00
trph
cc32f5ff61 fix: unify Responses output indexes for streamed items 2026-03-24 08:59:09 +08:00
trph
fbff68b9e0 fix: preserve choice-aware output indexes for streamed tool calls 2026-03-24 08:54:43 +08:00
trph
7e1a543b79 fix: preserve separate streamed tool calls in Responses API 2026-03-24 08:51:15 +08:00
Xvvln
7333619f15 fix: reject oversized downloads instead of truncating; warn on unverified fallback
- Read maxAssetDownloadSize+1 bytes and error if exceeded, preventing
  silent truncation that could write a broken management.html to disk
- Log explicit warning when fallback URL is used without digest
  verification, so users are aware of the reduced security guarantee
2026-03-24 00:27:44 +08:00
DragonFSKY
74b862d8b8 test(cliproxy): cover delete re-add stale state flow 2026-03-24 00:21:04 +08:00
Xvvln
2db8df8e38 fix(security): harden management panel asset updater
- Abort update when SHA256 digest mismatch is detected instead of
  logging a warning and proceeding (prevents MITM asset replacement)
- Cap asset download size to 10 MB via io.LimitReader (defense-in-depth
  against OOM from oversized responses)
- Add `auto-update-panel` config option (default: false) to make the
  periodic background updater opt-in; the panel is still downloaded
  on first access when missing, but no longer silently auto-updated
  every 3 hours unless explicitly enabled
2026-03-24 00:10:04 +08:00
DragonFSKY
5c817a9b42 fix(auth): prevent stale ModelStates inheritance from disabled auth entries
When an auth file is deleted and re-created with the same path/ID, the
new auth could inherit stale ModelStates (cooldown/backoff) from the
previously disabled entry, preventing it from being routed.

Gate runtime state inheritance (ModelStates, LastRefreshedAt,
NextRefreshAfter) on both existing and incoming auth being non-disabled
in Manager.Update and Service.applyCoreAuthAddOrUpdate.

Closes #2061
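The gating described above amounts to a guard like this (the struct is trimmed to the fields named in this message; the real types carry much more state):

```go
package main

import "fmt"

// Auth is reduced to the runtime-state fields the commit gates on.
type Auth struct {
	Disabled         bool
	ModelStates      map[string]string
	LastRefreshedAt  string
	NextRefreshAfter string
}

// inheritRuntimeState carries cooldown/backoff bookkeeping forward only
// when neither the existing nor the incoming auth is disabled, so a
// deleted-and-recreated auth starts from a clean slate.
func inheritRuntimeState(existing, incoming *Auth) {
	if existing == nil || existing.Disabled || incoming.Disabled {
		return // never inherit across a disable/delete cycle
	}
	incoming.ModelStates = existing.ModelStates
	incoming.LastRefreshedAt = existing.LastRefreshedAt
	incoming.NextRefreshAfter = existing.NextRefreshAfter
}

func main() {
	stale := &Auth{Disabled: true, ModelStates: map[string]string{"claude-sonnet-4": "cooldown"}}
	fresh := &Auth{}
	inheritRuntimeState(stale, fresh)
	fmt.Println("inherited stale state:", fresh.ModelStates != nil)
}
```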
2026-03-14 23:46:23 +08:00
luxvtz
5da0decef6 Improve GitLab Duo gateway compatibility 2026-03-14 03:18:43 -07:00
128 changed files with 14498 additions and 1868 deletions

3
.gitignore vendored
View File

@@ -1,6 +1,7 @@
# Binaries
cli-proxy-api
cliproxy
/server
*.exe
@@ -36,13 +37,13 @@ GEMINI.md
# Tooling metadata
.vscode/*
.worktrees/
.codex/*
.claude/*
.gemini/*
.serena/*
.agent/*
.agents/*
.agents/*
.opencode/*
.idea/*
.bmad/*

View File

@@ -8,6 +8,7 @@ builds:
- linux
- windows
- darwin
- freebsd
goarch:
- amd64
- arm64

View File

@@ -14,6 +14,99 @@ This project only accepts pull requests that relate to third-party provider supp
If you need to submit any non-third-party provider changes, please open them against the [mainline](https://github.com/router-for-me/CLIProxyAPI) repository.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
## Who is with us?
These projects are based on CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
### [Quotio](https://github.com/nguyenphutrong/quotio)
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AI is an AI assistant tool designed specifically for restricted environments. It runs stealthily, with no windows and no traces, and enables cross-device AI Q&A interaction and control over the local area network (LAN). Essentially, it is an automated collaboration layer of "screen/audio capture + AI inference + low-friction delivery" that helps users immersively use AI assistants across applications on controlled devices or in restricted environments.
### [ProxyPal](https://github.com/buddingnewinsights/proxypal)
Cross-platform desktop app (macOS, Windows, Linux) wrapping CLIProxyAPI with a native GUI. Connects Claude, ChatGPT, Gemini, GitHub Copilot, Qwen, iFlow, and custom OpenAI-compatible endpoints with usage analytics, request monitoring, and auto-configuration for popular coding tools - no API keys needed.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
These projects are ports of CLIProxyAPI or inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
Never stop coding. Smart routing to FREE & low-cost AI models with automatic fallback.
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

View File

@@ -12,6 +12,96 @@
如果需要提交任何非第三方供应商支持的 Pull Request,请提交到[主线](https://github.com/router-for-me/CLIProxyAPI)版本。
1. Fork 仓库
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`)
3. 提交您的更改(`git commit -m 'Add some amazing feature'`)
4. 推送到分支(`git push origin feature/amazing-feature`)
5. 打开 Pull Request
## 谁与我们在一起?
这些项目基于 CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLIProxyAPI 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型(Gemini、Codex、Antigravity),无需 API 密钥。
### [Quotio](https://github.com/nguyenphutrong/quotio)
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话(Claude Code、Codex、Gemini CLI),提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus),以及自动下载与更新。
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君是一款用于管理AI编程助手的跨平台桌面应用,支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具,本地代理实现多账户配额跟踪和一键配置。
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口、无痕迹的隐蔽运行方式,并通过局域网实现跨设备的 AI 问答交互与控制。本质上是一个「屏幕/音频采集 + AI 推理 + 低摩擦投送」的自动化协作层,帮助用户在受控设备/受限环境下沉浸式跨应用地使用 AI 助手。
### [ProxyPal](https://github.com/buddingnewinsights/proxypal)
跨平台桌面应用(macOS、Windows、Linux),以原生 GUI 封装 CLIProxyAPI。支持连接 Claude、ChatGPT、Gemini、GitHub Copilot、Qwen、iFlow 及自定义 OpenAI 兼容端点,具备使用分析、请求监控和热门编程工具自动配置功能,无需 API 密钥。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用,自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
代码不止,创新不停。智能路由至免费及低成本 AI 模型,并支持自动故障转移。
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -30,6 +30,14 @@ GLM CODING PLANを10%割引で取得https://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>LingtrueAPIのスポンサーシップに感謝します!LingtrueAPIはグローバルな大規模モデルAPIリレーサービスプラットフォームで、Claude Code、Codex、GeminiなどのトップモデルAPI呼び出しサービスを提供し、ユーザーが低コストかつ高い安定性で世界中のAI能力に接続できるよう支援しています。LingtrueAPIは本ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.lingtrue.com/register">こちらのリンク</a>から登録し、初回チャージ時にプロモーションコード「LingtrueAPI」を入力すると10%割引になります。</td>
</tr>
</tbody>
</table>
@@ -74,6 +82,14 @@ CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
特定のバックエンド系統のリクエスト/レスポンス形状が必要な場合は、統合された `/v1/...` エンドポイントよりも provider-specific のパスを優先してください。
- messages 系のバックエンドには `/api/provider/{provider}/v1/messages`
- モデル単位の generate 系エンドポイントには `/api/provider/{provider}/v1beta/models/...`
- chat-completions 系のバックエンドには `/api/provider/{provider}/v1/chat/completions`
これらのパスはプロトコル面の選択には役立ちますが、同じクライアント向けモデル名が複数バックエンドで再利用されている場合、それだけで推論実行系が一意に固定されるわけではありません。実際の推論ルーティングは、引き続きリクエスト内の model/alias 解決に従います。厳密にバックエンドを固定したい場合は、一意な alias や prefix を使うか、クライアント向けモデル名の重複自体を避けてください。
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
## SDKドキュメント
@@ -110,10 +126,6 @@ CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデル(Gemini、Codex、Antigravity)を即座に切り替えるCLIラッパー - APIキー不要
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
CLIProxyAPI管理用のmacOSネイティブGUIOAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
### [Quotio](https://github.com/nguyenphutrong/quotio)
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
@@ -158,6 +170,10 @@ New API互換リレーサイトアカウントをワンストップで管理す
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LANローカルエリアネットワークを介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
### [ProxyPal](https://github.com/buddingnewinsights/proxypal)
CLIProxyAPIをネイティブGUIでラップしたクロスプラットフォームデスクトップアプリ(macOS、Windows、Linux)。Claude、ChatGPT、Gemini、GitHub Copilot、Qwen、iFlow、カスタムOpenAI互換エンドポイントに対応し、使用状況分析、リクエスト監視、人気コーディングツールの自動設定機能を搭載 - APIキー不要
> [!NOTE]
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。

BIN
assets/bmoplus.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

BIN
assets/lingtrue.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

20
cmd/mcpdebug/main.go Normal file
View File

@@ -0,0 +1,20 @@
package main
import (
"encoding/hex"
"fmt"
"os"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
// Encode MCP result with empty execId
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
// Write to file for analysis
_ = os.WriteFile("mcp_result.bin", resultBytes, 0o644)
fmt.Println("Wrote mcp_result.bin")
}

32
cmd/protocheck/main.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"fmt"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
ecm := cursorproto.NewMsg("ExecClientMessage")
// Try different field names
names := []string{
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
"shell_result", "shellResult",
}
for _, name := range names {
fd := ecm.Descriptor().Fields().ByName(name)
if fd != nil {
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
} else {
fmt.Printf("Field %q NOT FOUND\n", name)
}
}
// List all fields
fmt.Println("\nAll fields in ExecClientMessage:")
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
f := ecm.Descriptor().Fields().Get(i)
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
}
}

View File

@@ -85,6 +85,7 @@ func main() {
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var cursorLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
var kiroAWSLogin bool
@@ -123,6 +124,7 @@ func main() {
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
@@ -544,6 +546,8 @@ func main() {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if cursorLogin {
cmd.DoCursorLogin(cfg, options)
} else if kiroLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito

View File

@@ -25,6 +25,10 @@ remote-management:
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
# When set to true, the panel is only downloaded on first access if missing, and never auto-updated afterward.
# disable-auto-update-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
@@ -92,6 +96,7 @@ max-retry-interval: 30
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
# Routing strategy for selecting credentials when multiple match.
routing:
@@ -173,6 +178,8 @@ nonstream-keepalive-interval: 0
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
@@ -309,6 +316,10 @@ nonstream-keepalive-interval: 0
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
# you select the protocol surface, but inference backend selection can still follow the resolved
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
# antigravity:

1
go.mod
View File

@@ -83,6 +83,7 @@ require (
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pierrec/xxHash v0.1.5
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect

2
go.sum
View File

@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=

View File

@@ -29,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
@@ -550,10 +551,23 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
}
func isUnsafeAuthFileName(name string) bool {
if strings.TrimSpace(name) == "" {
return true
}
if strings.ContainsAny(name, "/\\") {
return true
}
if filepath.VolumeName(name) != "" {
return true
}
return false
}
// Download single auth file by name
func (h *Handler) DownloadAuthFile(c *gin.Context) {
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
name := strings.TrimSpace(c.Query("name"))
if isUnsafeAuthFileName(name) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
@@ -635,8 +649,8 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
name := strings.TrimSpace(c.Query("name"))
if isUnsafeAuthFileName(name) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
@@ -869,7 +883,7 @@ func uniqueAuthFileNames(names []string) []string {
func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
name = strings.TrimSpace(name)
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
if isUnsafeAuthFileName(name) {
return "", http.StatusBadRequest, fmt.Errorf("invalid name")
}
@@ -1033,6 +1047,7 @@ func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Aut
auth.Runtime = existing.Runtime
}
}
coreauth.ApplyCustomHeadersFromMetadata(auth)
return auth, nil
}
@@ -1115,7 +1130,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
// PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
@@ -1123,11 +1138,12 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
}
var req struct {
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
Note *string `json:"note"`
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Headers map[string]string `json:"headers"`
Priority *int `json:"priority"`
Note *string `json:"note"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
@@ -1163,13 +1179,107 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
changed := false
if req.Prefix != nil {
targetAuth.Prefix = *req.Prefix
prefix := strings.TrimSpace(*req.Prefix)
targetAuth.Prefix = prefix
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if prefix == "" {
delete(targetAuth.Metadata, "prefix")
} else {
targetAuth.Metadata["prefix"] = prefix
}
changed = true
}
if req.ProxyURL != nil {
targetAuth.ProxyURL = *req.ProxyURL
proxyURL := strings.TrimSpace(*req.ProxyURL)
targetAuth.ProxyURL = proxyURL
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if proxyURL == "" {
delete(targetAuth.Metadata, "proxy_url")
} else {
targetAuth.Metadata["proxy_url"] = proxyURL
}
changed = true
}
if len(req.Headers) > 0 {
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata)
nextHeaders := make(map[string]string, len(existingHeaders))
for k, v := range existingHeaders {
nextHeaders[k] = v
}
headerChanged := false
for key, value := range req.Headers {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
attrKey := "header:" + name
if val == "" {
if _, ok := nextHeaders[name]; ok {
delete(nextHeaders, name)
headerChanged = true
}
if targetAuth.Attributes != nil {
if _, ok := targetAuth.Attributes[attrKey]; ok {
headerChanged = true
}
}
continue
}
if prev, ok := nextHeaders[name]; !ok || prev != val {
headerChanged = true
}
nextHeaders[name] = val
if targetAuth.Attributes != nil {
if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val {
headerChanged = true
}
} else {
headerChanged = true
}
}
if headerChanged {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if targetAuth.Attributes == nil {
targetAuth.Attributes = make(map[string]string)
}
for key, value := range req.Headers {
name := strings.TrimSpace(key)
if name == "" {
continue
}
val := strings.TrimSpace(value)
attrKey := "header:" + name
if val == "" {
delete(nextHeaders, name)
delete(targetAuth.Attributes, attrKey)
continue
}
nextHeaders[name] = val
targetAuth.Attributes[attrKey] = val
}
if len(nextHeaders) == 0 {
delete(targetAuth.Metadata, "headers")
} else {
metaHeaders := make(map[string]any, len(nextHeaders))
for k, v := range nextHeaders {
metaHeaders[k] = v
}
targetAuth.Metadata["headers"] = metaHeaders
}
changed = true
}
}
if req.Priority != nil || req.Note != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
@@ -2124,9 +2234,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata["oauth_client_id"] = clientID
if clientSecret != "" {
metadata["oauth_client_secret"] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
@@ -3694,3 +3801,84 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
"verification_uri": resp.VerificationURL,
})
}
// RequestCursorToken initiates the Cursor PKCE authentication flow.
// Supports multiple accounts via ?label=xxx query parameter.
// The user opens the returned URL in a browser, logs in, and the server polls
// until the authentication completes.
func (h *Handler) RequestCursorToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
label := strings.TrimSpace(c.Query("label"))
log.Infof("Initializing Cursor authentication (label=%q)...", label)
authParams, err := cursorauth.GenerateAuthParams()
if err != nil {
log.Errorf("Failed to generate Cursor auth params: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate auth params"})
return
}
state := fmt.Sprintf("cur-%d", time.Now().UnixNano())
RegisterOAuthSession(state, "cursor")
go func() {
log.Info("Waiting for Cursor authentication...")
log.Infof("Open this URL in your browser: %s", authParams.LoginURL)
tokens, errPoll := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed: "+errPoll.Error())
log.Errorf("Cursor authentication failed: %v", errPoll)
return
}
// Build metadata
metadata := map[string]any{
"type": "cursor",
"access_token": tokens.AccessToken,
"refresh_token": tokens.RefreshToken,
"timestamp": time.Now().UnixMilli(),
}
// Extract expiry and account identity from JWT
expiry := cursorauth.GetTokenExpiry(tokens.AccessToken)
if !expiry.IsZero() {
metadata["expires_at"] = expiry.Format(time.RFC3339)
}
// Auto-identify account from JWT sub claim for multi-account support
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
subHash := cursorauth.SubToShortHash(sub)
if sub != "" {
metadata["sub"] = sub
}
fileName := cursorauth.CredentialFileName(label, subHash)
displayLabel := cursorauth.DisplayLabel(label, subHash)
record := &coreauth.Auth{
ID: fileName,
Provider: "cursor",
FileName: fileName,
Label: displayLabel,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save Cursor tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save tokens")
return
}
log.Infof("Cursor authentication successful! Token saved to %s", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("cursor")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authParams.LoginURL,
"state": state,
})
}


@@ -0,0 +1,62 @@
package management
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
fileName := "download-user.json"
expected := []byte(`{"type":"codex"}`)
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
if got := rec.Body.Bytes(); string(got) != string(expected) {
t.Fatalf("unexpected download content: %q", string(got))
}
}
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
for _, name := range []string{
"../external/secret.json",
`..\\external\\secret.json`,
"nested/secret.json",
`nested\\secret.json`,
} {
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
}
}
}


@@ -0,0 +1,51 @@
//go:build windows
package management
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
tempDir := t.TempDir()
authDir := filepath.Join(tempDir, "auth")
externalDir := filepath.Join(tempDir, "external")
if err := os.MkdirAll(authDir, 0o700); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
if err := os.MkdirAll(externalDir, 0o700); err != nil {
t.Fatalf("failed to create external dir: %v", err)
}
secretName := "secret.json"
secretPath := filepath.Join(externalDir, secretName)
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
t.Fatalf("failed to write external file: %v", err)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(
http.MethodGet,
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
nil,
)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
}
}


@@ -0,0 +1,164 @@
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
record := &coreauth.Auth{
ID: "test.json",
FileName: "test.json",
Provider: "claude",
Attributes: map[string]string{
"path": "/tmp/test.json",
"header:X-Old": "old",
"header:X-Remove": "gone",
},
Metadata: map[string]any{
"type": "claude",
"headers": map[string]any{
"X-Old": "old",
"X-Remove": "gone",
},
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
ctx.Request = req
h.PatchAuthFileFields(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
updated, ok := manager.GetByID("test.json")
if !ok || updated == nil {
t.Fatalf("expected auth record to exist after patch")
}
if updated.Prefix != "p1" {
t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
}
if updated.ProxyURL != "http://proxy.local" {
t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
}
if updated.Metadata == nil {
t.Fatalf("expected metadata to be non-nil")
}
if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
}
if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
}
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
if !ok {
raw, _ := json.Marshal(updated.Metadata["headers"])
t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
}
if got := headersMeta["X-Old"]; got != "new" {
t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
}
if got := headersMeta["X-New"]; got != "v" {
t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
}
if _, ok := headersMeta["X-Remove"]; ok {
t.Fatalf("expected metadata.headers.X-Remove to be deleted")
}
if _, ok := headersMeta["X-Nope"]; ok {
t.Fatalf("expected metadata.headers.X-Nope to be absent")
}
if got := updated.Attributes["header:X-Old"]; got != "new" {
t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
}
if got := updated.Attributes["header:X-New"]; got != "v" {
t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
}
if _, ok := updated.Attributes["header:X-Remove"]; ok {
t.Fatalf("expected attrs header:X-Remove to be deleted")
}
if _, ok := updated.Attributes["header:X-Nope"]; ok {
t.Fatalf("expected attrs header:X-Nope to be absent")
}
}
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
record := &coreauth.Auth{
ID: "noop.json",
FileName: "noop.json",
Provider: "claude",
Attributes: map[string]string{
"path": "/tmp/noop.json",
"header:X-Kee": "1",
},
Metadata: map[string]any{
"type": "claude",
"headers": map[string]any{
"X-Kee": "1",
},
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
body := `{"name":"noop.json","note":"hello","headers":{}}`
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
ctx.Request = req
h.PatchAuthFileFields(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
updated, ok := manager.GetByID("noop.json")
if !ok || updated == nil {
t.Fatalf("expected auth record to exist after patch")
}
if got := updated.Attributes["header:X-Kee"]; got != "1" {
t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
}
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
if !ok {
t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
}
if got := headersMeta["X-Kee"]; got != "1" {
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
}
}


@@ -15,6 +15,8 @@ import (
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse)
}
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
if len(apiWebsocketTimeline) > 0 {
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
}
if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil
return err
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
return data
}
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
if !isExist {
return nil
}
data, ok := apiTimeline.([]byte)
if !ok || len(data) == 0 {
return nil
}
return bytes.Clone(data)
}
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist {
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
}
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.body == nil || w.body.Len() == 0 {
return nil
}
return bytes.Clone(w.body.Bytes())
}
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
}
func extractBodyOverride(c *gin.Context, key string) []byte {
if c == nil {
return nil
}
bodyOverride, isExist := c.Get(key)
if !isExist {
return nil
}
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode,
headers,
body,
websocketTimeline,
apiRequestBody,
apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode,
headers,
body,
websocketTimeline,
apiRequestBody,
apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,


@@ -1,10 +1,14 @@
package middleware
import (
"bytes"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
wrapper.body.WriteString("original-response")
body := wrapper.extractResponseBody(c)
if string(body) != "original-response" {
t.Fatalf("response body = %q, want %q", string(body), "original-response")
}
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
body = wrapper.extractResponseBody(c)
if string(body) != "override-response" {
t.Fatalf("response body = %q, want %q", string(body), "override-response")
}
body[0] = 'X'
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
t.Fatalf("response override should be cloned, got %q", string(got))
}
}
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
body := wrapper.extractResponseBody(c)
if string(body) != "override-response-as-string" {
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
}
}
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
override := []byte("body-override")
c.Set(requestBodyOverrideContextKey, override)
body := extractBodyOverride(c, requestBodyOverrideContextKey)
if !bytes.Equal(body, override) {
t.Fatalf("body override = %q, want %q", string(body), string(override))
}
body[0] = 'X'
if !bytes.Equal(override, []byte("body-override")) {
t.Fatalf("override mutated: %q", string(override))
}
}
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
if got := wrapper.extractWebsocketTimeline(c); got != nil {
t.Fatalf("expected nil websocket timeline, got %q", string(got))
}
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
body := wrapper.extractWebsocketTimeline(c)
if string(body) != "timeline" {
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
}
}
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
streamWriter := &testStreamingLogWriter{}
wrapper := &ResponseWriterWrapper{
ResponseWriter: c.Writer,
logger: &testRequestLogger{enabled: true},
requestInfo: &RequestInfo{
URL: "/v1/responses",
Method: "POST",
Headers: map[string][]string{"Content-Type": {"application/json"}},
RequestID: "req-1",
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
},
isStreaming: true,
streamWriter: streamWriter,
}
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
if err := wrapper.Finalize(c); err != nil {
t.Fatalf("Finalize error: %v", err)
}
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
}
if !streamWriter.closed {
t.Fatal("expected stream writer to be closed")
}
}
type testRequestLogger struct {
enabled bool
}
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
return nil
}
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
return &testStreamingLogWriter{}, nil
}
func (l *testRequestLogger) IsEnabled() bool {
return l.enabled
}
type testStreamingLogWriter struct {
apiWebsocketTimeline []byte
closed bool
}
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
func (w *testStreamingLogWriter) Close() error {
w.closed = true
return nil
}


@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
return
}
// Sanitize request body: remove thinking blocks with invalid signatures
// to prevent upstream API 400 errors
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
// Restore the body for the handler to read
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -259,10 +263,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
} else if len(providers) > 0 {
// Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
// Wrap with ResponseRewriter for local providers too, because upstream
// proxies (e.g. NewAPI) may return a different model name and lack
// Amp-required fields like thinking.signature.
rewriter := NewResponseRewriter(c.Writer, modelName)
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
rewriter.Flush()
} else {
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))


@@ -2,6 +2,7 @@ package amp
import (
"bytes"
"fmt"
"net/http"
"strings"
@@ -12,20 +13,23 @@ import (
)
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
// It's used to rewrite model names in responses when model mapping is used
// It is used to rewrite model names in responses when model mapping is used
// and to keep Amp-compatible response shapes.
type ResponseRewriter struct {
gin.ResponseWriter
body *bytes.Buffer
originalModel string
isStreaming bool
body *bytes.Buffer
originalModel string
isStreaming bool
suppressedContentBlock map[int]struct{}
}
// NewResponseRewriter creates a new response rewriter for model name substitution
// NewResponseRewriter creates a new response rewriter for model name substitution.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
return &ResponseRewriter{
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
suppressedContentBlock: make(map[int]struct{}),
}
}
@@ -33,15 +37,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
func looksLikeSSEChunk(data []byte) bool {
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
// Heuristics are intentionally simple and cheap.
return bytes.Contains(data, []byte("data:")) ||
bytes.Contains(data, []byte("event:")) ||
bytes.Contains(data, []byte("message_start")) ||
bytes.Contains(data, []byte("message_delta")) ||
bytes.Contains(data, []byte("content_block_start")) ||
bytes.Contains(data, []byte("content_block_delta")) ||
bytes.Contains(data, []byte("content_block_stop")) ||
bytes.Contains(data, []byte("\n\n"))
// We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
for _, line := range bytes.Split(data, []byte("\n")) {
trimmed := bytes.TrimSpace(line)
if bytes.HasPrefix(trimmed, []byte("data:")) ||
bytes.HasPrefix(trimmed, []byte("event:")) {
return true
}
}
return false
}
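The line-scan heuristic above can be exercised in isolation. This is a standalone copy (renamed `sseLikeChunk` so it compiles outside the rewriter) showing that a chunk counts as SSE only when some line starts with `data:` or `event:`, while a plain JSON body that merely contains those strings mid-line does not:

```go
package main

import (
	"bytes"
	"fmt"
)

// sseLikeChunk mirrors looksLikeSSEChunk: a chunk "looks like" SSE
// when any line, after trimming whitespace, begins with a field prefix.
func sseLikeChunk(data []byte) bool {
	for _, line := range bytes.Split(data, []byte("\n")) {
		trimmed := bytes.TrimSpace(line)
		if bytes.HasPrefix(trimmed, []byte("data:")) ||
			bytes.HasPrefix(trimmed, []byte("event:")) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(sseLikeChunk([]byte("event: message_start\ndata: {}\n\n"))) // true
	fmt.Println(sseLikeChunk([]byte(`{"type":"message","content":[]}`)))   // false
}
```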
func (rw *ResponseRewriter) enableStreaming(reason string) error {
@@ -106,7 +110,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
return rw.body.Write(data)
}
// Flush writes the buffered response with model names rewritten
func (rw *ResponseRewriter) Flush() {
if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
@@ -115,26 +118,68 @@ func (rw *ResponseRewriter) Flush() {
return
}
if rw.body.Len() > 0 {
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
// Update Content-Length to match the rewritten body size, since
// signature injection and model name changes alter the payload length.
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
}
}
}
// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
// The Amp client struggles when both thinking and tool_use blocks are present
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
for index, block := range gjson.GetBytes(data, "content").Array() {
blockType := block.Get("type").String()
if blockType != "tool_use" && blockType != "thinking" {
continue
}
signaturePath := fmt.Sprintf("content.%d.signature", index)
if gjson.GetBytes(data, signaturePath).Exists() {
continue
}
var err error
data, err = sjson.SetBytes(data, signaturePath, "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
break
}
}
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
var err error
data, err = sjson.SetBytes(data, "content_block.signature", "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
}
}
return data
}
func (rw *ResponseRewriter) markSuppressedContentBlock(index int) {
if rw.suppressedContentBlock == nil {
rw.suppressedContentBlock = make(map[int]struct{})
}
rw.suppressedContentBlock[index] = struct{}{}
}
func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool {
_, ok := rw.suppressedContentBlock[index]
return ok
}
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount {
var err error
data, err = sjson.SetBytes(data, "content", filtered.Value())
@@ -142,13 +187,41 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
}
}
}
}
eventType := gjson.GetBytes(data, "type").String()
indexResult := gjson.GetBytes(data, "index")
if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() {
rw.markSuppressedContentBlock(int(indexResult.Int()))
return nil
}
if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" {
if indexResult.Exists() {
rw.markSuppressedContentBlock(int(indexResult.Int()))
}
return nil
}
if eventType == "content_block_stop" && indexResult.Exists() {
index := int(indexResult.Int())
if rw.isSuppressedContentBlock(index) {
delete(rw.suppressedContentBlock, index)
return nil
}
}
return data
}
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
}
if rw.originalModel == "" {
return data
}
@@ -160,24 +233,148 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
return data
}
// rewriteStreamChunk rewrites model names in SSE stream chunks.
// SSE format: "event: <type>\ndata: {json}\n\n"
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
lines := bytes.Split(chunk, []byte("\n"))
var out [][]byte
i := 0
for i < len(lines) {
line := lines[i]
trimmed := bytes.TrimSpace(line)
// Case 1: "event:" line - look ahead for its "data:" line
if bytes.HasPrefix(trimmed, []byte("event: ")) {
// Scan forward past blank lines to find the data: line
dataIdx := -1
for j := i + 1; j < len(lines); j++ {
t := bytes.TrimSpace(lines[j])
if len(t) == 0 {
continue
}
if bytes.HasPrefix(t, []byte("data: ")) {
dataIdx = j
}
break
}
if dataIdx >= 0 {
// Found event+data pair - process through rewriter
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
// Emit event line
out = append(out, line)
// Emit blank lines between event and data
for k := i + 1; k < dataIdx; k++ {
out = append(out, lines[k])
}
// Emit rewritten data
out = append(out, append([]byte("data: "), rewritten...))
i = dataIdx + 1
continue
}
}
// No data line found (orphan event from a cross-chunk split).
// Pass it through as-is - the data will arrive in the next chunk.
out = append(out, line)
i++
continue
}
// Case 2: standalone "data:" line (no preceding event: in this chunk)
if bytes.HasPrefix(trimmed, []byte("data: ")) {
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
out = append(out, append([]byte("data: "), rewritten...))
i++
continue
}
}
// Case 3: everything else
out = append(out, line)
i++
}
return bytes.Join(out, []byte("\n"))
}
// rewriteStreamEvent processes a single JSON event in the SSE stream.
// It rewrites model names and ensures signature fields exist.
// NOTE: streaming mode does NOT suppress thinking blocks - they are
// passed through with signature injection to avoid breaking SSE index
// alignment and TUI rendering.
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.originalModel)
}
}
}
return data
}
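The SSE framing above can be sketched independently of the rewriter. This is a minimal stand-alone version, with a hypothetical name `frameSSE` and a `rewrite` callback standing in for `rewriteStreamEvent`; it shows the event/data lookahead and the orphan-event passthrough, not the actual handler code:

```go
package main

import "bytes"

// frameSSE walks an SSE chunk line by line. An "event:" line is paired with
// the next non-blank "data:" line and the JSON payload is passed through
// rewrite; a standalone "data:" line is rewritten in place; an orphan
// "event:" line (split across chunks) is passed through unchanged.
func frameSSE(chunk []byte, rewrite func([]byte) []byte) []byte {
	lines := bytes.Split(chunk, []byte("\n"))
	var out [][]byte
	i := 0
	for i < len(lines) {
		line := lines[i]
		trimmed := bytes.TrimSpace(line)
		if bytes.HasPrefix(trimmed, []byte("event: ")) {
			// Look ahead past blank lines for the matching data: line.
			dataIdx := -1
			for j := i + 1; j < len(lines); j++ {
				t := bytes.TrimSpace(lines[j])
				if len(t) == 0 {
					continue
				}
				if bytes.HasPrefix(t, []byte("data: ")) {
					dataIdx = j
				}
				break
			}
			if dataIdx >= 0 {
				jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
				if len(jsonData) > 0 && jsonData[0] == '{' {
					out = append(out, line)
					for k := i + 1; k < dataIdx; k++ {
						out = append(out, lines[k])
					}
					out = append(out, append([]byte("data: "), rewrite(jsonData)...))
					i = dataIdx + 1
					continue
				}
			}
			// Orphan event: the data line will arrive in the next chunk.
			out = append(out, line)
			i++
			continue
		}
		if bytes.HasPrefix(trimmed, []byte("data: ")) {
			jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
			if len(jsonData) > 0 && jsonData[0] == '{' {
				out = append(out, append([]byte("data: "), rewrite(jsonData)...))
				i++
				continue
			}
		}
		out = append(out, line)
		i++
	}
	return bytes.Join(out, []byte("\n"))
}
```

Feeding it `"event: ping\ndata: {\"a\":1}\n\n"` with an identity-plus-marker callback rewrites only the JSON payload, while a chunk ending in a bare `event:` line comes back untouched.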
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// from the messages array in a request body before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
modified := false
for msgIdx, msg := range messages.Array() {
if msg.Get("role").String() != "assistant" {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
var keepBlocks []interface{}
removedCount := 0
for _, block := range content.Array() {
blockType := block.Get("type").String()
if blockType == "thinking" {
sig := block.Get("signature")
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
removedCount++
continue
}
}
keepBlocks = append(keepBlocks, block.Value())
}
if removedCount > 0 {
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
var err error
if len(keepBlocks) == 0 {
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
} else {
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
}
if err != nil {
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
continue
}
modified = true
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
}
}
if modified {
log.Debugf("Amp RequestSanitizer: sanitized request body")
}
return body
}
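The sanitizer's filtering rule can be illustrated with only the standard library. The gjson/sjson code above is the actual implementation; this stand-alone sketch, with the hypothetical name `dropInvalidThinking`, just states the rule: a thinking block survives only when its signature is a non-blank string.

```go
package main

import "strings"

// dropInvalidThinking removes "thinking" blocks whose "signature" is
// missing, not a string, or blank - the same rule SanitizeAmpRequestBody
// applies before forwarding a request upstream.
func dropInvalidThinking(content []map[string]interface{}) []map[string]interface{} {
	kept := content[:0]
	for _, block := range content {
		if block["type"] == "thinking" {
			sig, ok := block["signature"].(string)
			if !ok || strings.TrimSpace(sig) == "" {
				continue // invalid signature: the upstream API would reject it with a 400
			}
		}
		kept = append(kept, block)
	}
	return kept
}
```

A whitespace-only signature and a numeric signature are both dropped; valid thinking blocks and non-thinking blocks pass through.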

View File

@@ -1,6 +1,7 @@
package amp
import (
"strings"
"testing"
)
@@ -100,6 +101,50 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
}
}
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
rw := &ResponseRewriter{suppressedContentBlock: make(map[int]struct{})}
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
result := rw.rewriteStreamChunk(chunk)
// Streaming mode preserves thinking blocks (does NOT suppress them)
// to avoid breaking SSE index alignment and TUI rendering
if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result))
}
if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
t.Fatalf("expected thinking_delta to be preserved, got %s", string(result))
}
if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result))
}
if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
}
// Signature should be injected into both thinking and tool_use blocks
if count := strings.Count(string(result), `"signature":""`); count != 2 {
t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result))
}
}
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
result := SanitizeAmpRequestBody(input)
if contains(result, []byte("drop-whitespace")) {
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
}
if contains(result, []byte("drop-number")) {
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
}
if !contains(result, []byte("keep-valid")) {
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
}
if !contains(result, []byte("keep-text")) {
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {

View File

@@ -682,6 +682,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)

View File

@@ -172,6 +172,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
true,
"issue-1711",
time.Now(),

View File

@@ -88,7 +88,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
"client_id": {ClientID},
"response_type": {"code"},
"redirect_uri": {RedirectURI},
"scope": {"org:create_api_key user:profile user:inference"},
"scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
"code_challenge": {pkceCodes.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},

View File

@@ -0,0 +1,33 @@
package cursor
import (
"fmt"
"strings"
)
// CredentialFileName returns the filename used to persist Cursor credentials.
// Priority: explicit label > auto-generated from JWT sub hash.
// If both label and subHash are empty, falls back to "cursor.json".
func CredentialFileName(label, subHash string) string {
label = strings.TrimSpace(label)
subHash = strings.TrimSpace(subHash)
if label != "" {
return fmt.Sprintf("cursor.%s.json", label)
}
if subHash != "" {
return fmt.Sprintf("cursor.%s.json", subHash)
}
return "cursor.json"
}
// DisplayLabel returns a human-readable label for the Cursor account.
func DisplayLabel(label, subHash string) string {
label = strings.TrimSpace(label)
if label != "" {
return "Cursor " + label
}
if subHash != "" {
return "Cursor " + subHash
}
return "Cursor User"
}
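The label > sub-hash > default priority can be seen directly from a few calls. The function below is copied from the file above so the sketch is self-contained:

```go
package main

import (
	"fmt"
	"strings"
)

// CredentialFileName returns the filename used to persist Cursor credentials.
// Priority: explicit label > auto-generated JWT sub hash > "cursor.json".
func CredentialFileName(label, subHash string) string {
	label = strings.TrimSpace(label)
	subHash = strings.TrimSpace(subHash)
	if label != "" {
		return fmt.Sprintf("cursor.%s.json", label)
	}
	if subHash != "" {
		return fmt.Sprintf("cursor.%s.json", subHash)
	}
	return "cursor.json"
}
```

Note that a whitespace-only label is trimmed away and falls through to the next candidate rather than producing `cursor. .json`.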

View File

@@ -0,0 +1,249 @@
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
package cursor
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
)
const (
CursorLoginURL = "https://cursor.com/loginDeepControl"
CursorPollURL = "https://api2.cursor.sh/auth/poll"
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
pollMaxAttempts = 150
pollBaseDelay = 1 * time.Second
pollMaxDelay = 10 * time.Second
pollBackoffMultiply = 1.2
maxConsecutiveErrors = 10
)
// AuthParams holds the PKCE parameters for Cursor login.
type AuthParams struct {
Verifier string
Challenge string
UUID string
LoginURL string
}
// TokenPair holds the access and refresh tokens from Cursor.
type TokenPair struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
// GeneratePKCE creates a PKCE verifier and challenge pair.
func GeneratePKCE() (verifier, challenge string, err error) {
verifierBytes := make([]byte, 96)
if _, err = rand.Read(verifierBytes); err != nil {
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
}
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
// GenerateAuthParams creates the full set of auth params for Cursor login.
func GenerateAuthParams() (*AuthParams, error) {
verifier, challenge, err := GeneratePKCE()
if err != nil {
return nil, err
}
uuidBytes := make([]byte, 16)
if _, err = rand.Read(uuidBytes); err != nil {
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
}
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
CursorLoginURL, challenge, uuid)
return &AuthParams{
Verifier: verifier,
Challenge: challenge,
UUID: uuid,
LoginURL: loginURL,
}, nil
}
// PollForAuth polls the Cursor auth endpoint until the user completes login.
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
delay := pollBaseDelay
consecutiveErrors := 0
client := &http.Client{Timeout: 10 * time.Second}
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(delay):
}
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
consecutiveErrors++
if consecutiveErrors >= maxConsecutiveErrors {
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
}
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Still waiting for user to authorize
consecutiveErrors = 0
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
}
return &tokens, nil
}
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
}
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
}
// RefreshToken refreshes a Cursor access token using the refresh token.
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
strings.NewReader("{}"))
if err != nil {
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+refreshToken)
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
}
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
}
// Keep original refresh token if not returned
if tokens.RefreshToken == "" {
tokens.RefreshToken = refreshToken
}
return &tokens, nil
}
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
// the account. Returns empty string if parsing fails.
func ParseJWTSub(token string) string {
decoded := decodeJWTPayload(token)
if decoded == nil {
return ""
}
var claims struct {
Sub string `json:"sub"`
}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
return claims.Sub
}
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
// e.g. "auth0|user_2x..." → "a3f8b2c1"
func SubToShortHash(sub string) string {
if sub == "" {
return ""
}
h := sha256.Sum256([]byte(sub))
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
}
// decodeJWTPayload decodes the payload (middle) part of a JWT.
func decodeJWTPayload(token string) []byte {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
payload = strings.ReplaceAll(payload, "-", "+")
payload = strings.ReplaceAll(payload, "_", "/")
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil
}
return decoded
}
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
// Falls back to 1 hour from now if the token can't be parsed.
func GetTokenExpiry(token string) time.Time {
decoded := decodeJWTPayload(token)
if decoded == nil {
return time.Now().Add(1 * time.Hour)
}
var claims struct {
Exp float64 `json:"exp"`
}
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
return time.Now().Add(1 * time.Hour)
}
sec, frac := math.Modf(claims.Exp)
expiry := time.Unix(int64(sec), int64(frac*1e9))
// Subtract 5-minute safety margin
return expiry.Add(-5 * time.Minute)
}
func minDuration(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,84 @@
package proto
import (
"encoding/binary"
"encoding/json"
"fmt"
)
const (
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
ConnectEndStreamFlag byte = 0x02
// ConnectCompressionFlag indicates the payload is compressed (not supported).
ConnectCompressionFlag byte = 0x01
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
ConnectFrameHeaderSize = 5
)
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
frame := make([]byte, ConnectFrameHeaderSize+len(data))
frame[0] = flags
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
copy(frame[5:], data)
return frame
}
// ParseConnectFrame extracts one frame from a buffer.
// Returns (flags, payload, bytesConsumed, ok).
// ok is false when the buffer is too short for a complete frame.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
if len(buf) < ConnectFrameHeaderSize {
return 0, nil, 0, false
}
flags = buf[0]
length := binary.BigEndian.Uint32(buf[1:5])
total := ConnectFrameHeaderSize + int(length)
if len(buf) < total {
return 0, nil, 0, false
}
return flags, buf[5:total], total, true
}
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
// The Code field contains the server-defined error code (e.g. gRPC standard codes
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
type ConnectError struct {
Code string // server-defined error code
Message string // human-readable error description
}
func (e *ConnectError) Error() string {
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
}
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
// Returns nil if there is no error in the trailer.
// On error, returns a *ConnectError with the server's error code and message.
func ParseConnectEndStream(data []byte) error {
if len(data) == 0 {
return nil
}
var trailer struct {
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
if err := json.Unmarshal(data, &trailer); err != nil {
return fmt.Errorf("failed to parse Connect end stream: %w", err)
}
if trailer.Error != nil {
code := trailer.Error.Code
if code == "" {
code = "unknown"
}
msg := trailer.Error.Message
if msg == "" {
msg = "Unknown error"
}
return &ConnectError{Code: code, Message: msg}
}
return nil
}
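Under the stated frame layout, framing and parsing should round-trip exactly. A self-contained check, with `FrameConnectMessage`/`ParseConnectFrame` copied from the file above:

```go
package main

import "encoding/binary"

const (
	// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
	ConnectEndStreamFlag byte = 0x02
	// ConnectFrameHeaderSize is the fixed 5-byte frame header.
	ConnectFrameHeaderSize = 5
)

// FrameConnectMessage wraps a payload in a Connect frame:
// [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
	frame := make([]byte, ConnectFrameHeaderSize+len(data))
	frame[0] = flags
	binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
	copy(frame[5:], data)
	return frame
}

// ParseConnectFrame extracts one frame from a buffer.
// ok is false when the buffer is too short for a complete frame.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
	if len(buf) < ConnectFrameHeaderSize {
		return 0, nil, 0, false
	}
	flags = buf[0]
	length := binary.BigEndian.Uint32(buf[1:5])
	total := ConnectFrameHeaderSize + int(length)
	if len(buf) < total {
		return 0, nil, 0, false
	}
	return flags, buf[5:total], total, true
}
```

The `consumed` return lets a caller loop over a buffer containing several back-to-back frames, and the `ok=false` path handles a frame split across network reads by simply waiting for more bytes.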

View File

@@ -0,0 +1,564 @@
package proto
import (
"encoding/hex"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
)
// ServerMessageType identifies the kind of decoded server message.
type ServerMessageType int
const (
ServerMsgUnknown ServerMessageType = iota
ServerMsgTextDelta // Text content delta
ServerMsgThinkingDelta // Thinking/reasoning delta
ServerMsgThinkingCompleted // Thinking completed
ServerMsgKvGetBlob // Server wants a blob
ServerMsgKvSetBlob // Server wants to store a blob
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
ServerMsgExecMcpArgs // Server wants MCP tool execution
ServerMsgExecShellArgs // Rejected: shell command
ServerMsgExecReadArgs // Rejected: file read
ServerMsgExecWriteArgs // Rejected: file write
ServerMsgExecDeleteArgs // Rejected: file delete
ServerMsgExecLsArgs // Rejected: directory listing
ServerMsgExecGrepArgs // Rejected: grep search
ServerMsgExecFetchArgs // Rejected: HTTP fetch
ServerMsgExecDiagnostics // Respond with empty diagnostics
ServerMsgExecShellStream // Rejected: shell stream
ServerMsgExecBgShellSpawn // Rejected: background shell
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
ServerMsgExecOther // Other exec types (respond with empty)
ServerMsgTurnEnded // Turn has ended (no more output)
ServerMsgHeartbeat // Server heartbeat
ServerMsgTokenDelta // Token usage delta
ServerMsgCheckpoint // Conversation checkpoint update
)
// DecodedServerMessage holds parsed data from an AgentServerMessage.
type DecodedServerMessage struct {
Type ServerMessageType
// For text/thinking deltas
Text string
// For KV messages
KvId uint32
BlobId []byte // hex-encoded blob ID
BlobData []byte // for setBlobArgs
// For exec messages
ExecMsgId uint32
ExecId string
// For MCP args
McpToolName string
McpToolCallId string
McpArgs map[string][]byte // arg name -> protobuf-encoded value
// For rejection context
Path string
Command string
WorkingDirectory string
Url string
// For other exec - the raw field number for building a response
ExecFieldNumber int
// For TokenDeltaUpdate
TokenDelta int64
// For conversation checkpoint update (raw bytes, not decoded)
CheckpointData []byte
}
// DecodeAgentServerMessage parses an AgentServerMessage and returns
// a structured representation of the first meaningful message found.
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return msg, fmt.Errorf("invalid tag")
}
data = data[n:]
switch typ {
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return msg, fmt.Errorf("invalid bytes field %d", num)
}
data = data[n:]
// Debug: log top-level ASM fields
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
switch num {
case ASM_InteractionUpdate:
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
decodeInteractionUpdate(val, msg)
case ASM_ExecServerMessage:
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
decodeExecServerMessage(val, msg)
case ASM_KvServerMessage:
decodeKvServerMessage(val, msg)
case ASM_ConversationCheckpoint:
msg.Type = ServerMsgCheckpoint
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
}
case protowire.VarintType:
_, n := protowire.ConsumeVarint(data)
if n < 0 {
return msg, fmt.Errorf("invalid varint field %d", num)
}
data = data[n:]
default:
// Skip unknown wire types
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return msg, fmt.Errorf("invalid field %d", num)
}
data = data[n:]
}
}
return msg, nil
}
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case IU_TextDelta:
msg.Type = ServerMsgTextDelta
msg.Text = decodeStringField(val, TDU_Text)
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
case IU_ThinkingDelta:
msg.Type = ServerMsgThinkingDelta
msg.Text = decodeStringField(val, TKD_Text)
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
case IU_ThinkingCompleted:
msg.Type = ServerMsgThinkingCompleted
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
case 2:
// tool_call_started - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
case 3:
// tool_call_completed - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
case 8:
// token_delta - extract token count
msg.Type = ServerMsgTokenDelta
msg.TokenDelta = decodeVarintField(val, 1)
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
case 13:
// heartbeat from server
msg.Type = ServerMsgHeartbeat
case 14:
// turn_ended - critical: model finished generating
msg.Type = ServerMsgTurnEnded
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
case 16:
// step_started - ignore
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
case 17:
// step_completed - ignore
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
default:
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == KSM_Id {
msg.KvId = uint32(val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case KSM_GetBlobArgs:
msg.Type = ServerMsgKvGetBlob
msg.BlobId = decodeBytesField(val, GBA_BlobId)
case KSM_SetBlobArgs:
msg.Type = ServerMsgKvSetBlob
decodeSetBlobArgs(val, msg)
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SBA_BlobId:
msg.BlobId = val
case SBA_BlobData:
msg.BlobData = val
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == ESM_Id {
msg.ExecMsgId = uint32(val)
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
// Debug: log all fields found in ExecServerMessage
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case ESM_ExecId:
msg.ExecId = string(val)
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
case ESM_RequestContextArgs:
msg.Type = ServerMsgExecRequestCtx
case ESM_McpArgs:
msg.Type = ServerMsgExecMcpArgs
decodeMcpArgs(val, msg)
case ESM_ShellArgs:
msg.Type = ServerMsgExecShellArgs
decodeShellArgs(val, msg)
case ESM_ShellStreamArgs:
msg.Type = ServerMsgExecShellStream
decodeShellArgs(val, msg)
case ESM_ReadArgs:
msg.Type = ServerMsgExecReadArgs
msg.Path = decodeStringField(val, RA_Path)
case ESM_WriteArgs:
msg.Type = ServerMsgExecWriteArgs
msg.Path = decodeStringField(val, WA_Path)
case ESM_DeleteArgs:
msg.Type = ServerMsgExecDeleteArgs
msg.Path = decodeStringField(val, DA_Path)
case ESM_LsArgs:
msg.Type = ServerMsgExecLsArgs
msg.Path = decodeStringField(val, LA_Path)
case ESM_GrepArgs:
msg.Type = ServerMsgExecGrepArgs
case ESM_FetchArgs:
msg.Type = ServerMsgExecFetchArgs
msg.Url = decodeStringField(val, FA_Url)
case ESM_DiagnosticsArgs:
msg.Type = ServerMsgExecDiagnostics
case ESM_BackgroundShellSpawn:
msg.Type = ServerMsgExecBgShellSpawn
decodeShellArgs(val, msg) // same structure
case ESM_WriteShellStdinArgs:
msg.Type = ServerMsgExecWriteShellStdin
default:
// Unknown exec types - only set if we haven't identified the type yet
// (other fields like span_context (19) come after the exec type field)
if msg.Type == ServerMsgUnknown {
msg.Type = ServerMsgExecOther
msg.ExecFieldNumber = int(num)
}
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
msg.McpArgs = make(map[string][]byte)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case MCA_Name:
msg.McpToolName = string(val)
case MCA_Args:
// Map entries are encoded as submessages with key=1, value=2
decodeMapEntry(val, msg.McpArgs)
case MCA_ToolCallId:
msg.McpToolCallId = string(val)
case MCA_ToolName:
// ToolName takes precedence if present
if msg.McpToolName == "" || string(val) != "" {
msg.McpToolName = string(val)
}
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMapEntry(data []byte, m map[string][]byte) {
var key string
var value []byte
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
if num == 1 {
key = string(val)
} else if num == 2 {
value = append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
if key != "" {
m[key] = value
}
}
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SHA_Command:
msg.Command = string(val)
case SHA_WorkingDirectory:
msg.WorkingDirectory = string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
// --- Helper decoders ---
// decodeStringField extracts a string from the first matching field in a submessage.
func decodeStringField(data []byte, targetField protowire.Number) string {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return ""
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return ""
}
data = data[n:]
if num == targetField {
return string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return ""
}
data = data[n:]
}
}
return ""
}
// decodeBytesField extracts bytes from the first matching field in a submessage.
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return nil
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return nil
}
data = data[n:]
if num == targetField {
return append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return nil
}
data = data[n:]
}
}
return nil
}
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return 0
}
data = data[n:]
if typ == protowire.VarintType {
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return 0
}
data = data[n:]
if num == targetField {
return int64(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return 0
}
data = data[n:]
}
}
return 0
}
// BlobIdHex returns the hex string of a blob ID for use as a map key.
func BlobIdHex(blobId []byte) string {
return hex.EncodeToString(blobId)
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,664 @@
// Package proto provides protobuf encoding for Cursor's gRPC API,
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
package proto
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/known/structpb"
)
// --- Public types ---
// RunRequestParams holds all data needed to build an AgentRunRequest.
type RunRequestParams struct {
ModelId string
SystemPrompt string
UserText string
MessageId string
ConversationId string
Images []ImageData
Turns []TurnData
McpTools []McpToolDef
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
}
type ImageData struct {
MimeType string
Data []byte
}
type TurnData struct {
UserText string
AssistantText string
}
type McpToolDef struct {
Name string
Description string
InputSchema json.RawMessage
}
// --- Helper: create a dynamic message and set fields ---
func newMsg(name string) *dynamicpb.Message {
return dynamicpb.NewMessage(Msg(name))
}
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
}
func setStr(msg *dynamicpb.Message, name, val string) {
if val != "" {
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
}
}
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
if len(val) > 0 {
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
}
}
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
}
func setBool(msg *dynamicpb.Message, name string, val bool) {
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
}
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
}
func marshal(msg *dynamicpb.Message) []byte {
b, err := proto.Marshal(msg)
if err != nil {
panic("cursor proto marshal: " + err.Error())
}
return b
}
// --- Encode functions mirroring cursor-fetch.ts ---
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
func EncodeHeartbeat() []byte {
hb := newMsg("ClientHeartbeat")
acm := newMsg("AgentClientMessage")
setMsg(acm, "client_heartbeat", hb)
return marshal(acm)
}
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
// Mirrors buildCursorRequest() in cursor-fetch.ts.
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
func EncodeRunRequest(p *RunRequestParams) []byte {
if p.RawCheckpoint != nil {
return encodeRunRequestWithCheckpoint(p)
}
if p.BlobStore == nil {
p.BlobStore = make(map[string][]byte)
}
// --- Conversation turns ---
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
var turnBytes [][]byte
for _, turn := range p.Turns {
// UserMessage for this turn
um := newMsg("UserMessage")
setStr(um, "text", turn.UserText)
setStr(um, "message_id", generateId())
umBytes := marshal(um)
// Steps (assistant response)
var stepBytes [][]byte
if turn.AssistantText != "" {
am := newMsg("AssistantMessage")
setStr(am, "text", turn.AssistantText)
step := newMsg("ConversationStep")
setMsg(step, "assistant_message", am)
stepBytes = append(stepBytes, marshal(step))
}
// AgentConversationTurnStructure (fields are bytes, not submessages)
agentTurn := newMsg("AgentConversationTurnStructure")
setBytes(agentTurn, "user_message", umBytes)
for _, sb := range stepBytes {
stepsField := field(agentTurn, "steps")
list := agentTurn.Mutable(stepsField).List()
list.Append(protoreflect.ValueOfBytes(sb))
}
// ConversationTurnStructure (oneof turn → agentConversationTurn)
cts := newMsg("ConversationTurnStructure")
setMsg(cts, "agent_conversation_turn", agentTurn)
turnBytes = append(turnBytes, marshal(cts))
}
// --- System prompt blob ---
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
blobId := sha256Sum(systemJSON)
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
// --- ConversationStateStructure ---
css := newMsg("ConversationStateStructure")
// rootPromptMessagesJson: repeated bytes
rootField := field(css, "root_prompt_messages_json")
rootList := css.Mutable(rootField).List()
rootList.Append(protoreflect.ValueOfBytes(blobId))
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
turnsField := field(css, "turns")
turnsList := css.Mutable(turnsField).List()
for _, tb := range turnBytes {
turnsList.Append(protoreflect.ValueOfBytes(tb))
}
turnsOldField := field(css, "turns_old")
if turnsOldField != nil {
turnsOldList := css.Mutable(turnsOldField).List()
for _, tb := range turnBytes {
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
}
}
// --- UserMessage (current) ---
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
// Images via SelectedContext
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// --- UserMessageAction ---
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
// --- ConversationAction ---
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
// --- ModelDetails ---
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// --- AgentRunRequest ---
arr := newMsg("AgentRunRequest")
setMsg(arr, "conversation_state", css)
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
// --- AgentClientMessage ---
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
// Build UserMessage
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// Build ConversationAction with UserMessageAction
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
caBytes := marshal(ca)
// Build ModelDetails
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
mdBytes := marshal(md)
// Build McpTools
var mcpToolsBytes []byte
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
mcpToolsBytes = marshal(mcpTools)
}
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
var arrBuf []byte
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
// field 2: action = ConversationAction
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
// field 3: model_details = ModelDetails
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
// field 4: mcp_tools = McpTools
if len(mcpToolsBytes) > 0 {
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
}
// field 5: conversation_id = string
if p.ConversationId != "" {
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
}
// Wrap in AgentClientMessage field 1 (run_request)
var acmBuf []byte
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
return acmBuf
}
// ResumeRequestParams holds data for a ResumeAction request.
type ResumeRequestParams struct {
ModelId string
ConversationId string
McpTools []McpToolDef
}
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
// Used to resume a conversation by conversation_id without re-sending full history.
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(p.McpTools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// ResumeAction
ra := newMsg("ResumeAction")
setMsg(ra, "request_context", rc)
// ConversationAction with resume_action
ca := newMsg("ConversationAction")
setMsg(ca, "resume_action", ra)
// ModelDetails
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// AgentRunRequest — no conversation_state needed for resume
arr := newMsg("AgentRunRequest")
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools at top level
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// --- KV response encoders ---
// Mirrors handleKvMessage() in cursor-fetch.ts
// EncodeKvGetBlobResult responds to a getBlobArgs request.
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
result := newMsg("GetBlobResult")
if blobData != nil {
setBytes(result, "blob_data", blobData)
}
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "get_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// EncodeKvSetBlobResult responds to a setBlobArgs request.
func EncodeKvSetBlobResult(kvId uint32) []byte {
result := newMsg("SetBlobResult")
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "set_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// --- Exec response encoders ---
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(tools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range tools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// RequestContextSuccess
rcs := newMsg("RequestContextSuccess")
setMsg(rcs, "request_context", rc)
// RequestContextResult (oneof success)
rcr := newMsg("RequestContextResult")
setMsg(rcr, "success", rcs)
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
}
// EncodeExecMcpResult responds with MCP tool result.
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
textContent := newMsg("McpTextContent")
setStr(textContent, "text", content)
contentItem := newMsg("McpToolResultContentItem")
setMsg(contentItem, "text", textContent)
success := newMsg("McpSuccess")
contentField := field(success, "content")
contentList := success.Mutable(contentField).List()
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
setBool(success, "is_error", isError)
result := newMsg("McpResult")
setMsg(result, "success", success)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// EncodeExecMcpError responds with MCP error.
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
mcpErr := newMsg("McpError")
setStr(mcpErr, "error", errMsg)
result := newMsg("McpResult")
setMsg(result, "error", mcpErr)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// --- Rejection encoders (mirror handleExecMessage rejections) ---
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("ReadRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("ReadResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
}
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("ShellResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
}
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("WriteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("WriteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
}
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("DeleteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("DeleteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
}
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("LsRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("LsResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
}
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
grepErr := newMsg("GrepError")
setStr(grepErr, "error", errMsg)
result := newMsg("GrepResult")
setMsg(result, "error", grepErr)
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
}
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
fetchErr := newMsg("FetchError")
setStr(fetchErr, "url", url)
setStr(fetchErr, "error", errMsg)
result := newMsg("FetchResult")
setMsg(result, "error", fetchErr)
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
}
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
result := newMsg("DiagnosticsResult")
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
}
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("BackgroundShellSpawnResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
}
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
wsErr := newMsg("WriteShellStdinError")
setStr(wsErr, "error", errMsg)
result := newMsg("WriteShellStdinResult")
setMsg(result, "error", wsErr)
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
}
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
// Mirrors sendExec() in cursor-fetch.ts.
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
ecm := newMsg("ExecClientMessage")
setUint32(ecm, "id", id)
// Force-set exec_id even if empty; Cursor requires this field to be present.
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
// Debug: check if field exists
fd := field(ecm, resultFieldName)
if fd == nil {
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
}
// Debug: log the actual field being set
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
acm := newMsg("AgentClientMessage")
setMsg(acm, "exec_client_message", ecm)
return marshal(acm)
}
func listFields(msg *dynamicpb.Message) []string {
var names []string
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
}
return names
}
// --- Utilities ---
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
if len(jsonData) == 0 {
return nil
}
var v interface{}
if err := json.Unmarshal(jsonData, &v); err != nil {
return jsonData // fallback to raw JSON if parsing fails
}
pbVal, err := structpb.NewValue(v)
if err != nil {
return jsonData // fallback
}
b, err := proto.Marshal(pbVal)
if err != nil {
return jsonData // fallback
}
return b
}
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
val := &structpb.Value{}
if err := proto.Unmarshal(data, val); err != nil {
return nil, err
}
return val.AsInterface(), nil
}
func sha256Sum(data []byte) []byte {
h := sha256.Sum256(data)
return h[:]
}
var idCounter uint64
// generateId returns a 32-hex-char id derived from a process-local counter.
// Note: idCounter is not synchronized, and only its low 24 bits feed the hash,
// so concurrent encoders or very long-lived processes could repeat ids.
func generateId() string {
	idCounter++
	h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
	return hex.EncodeToString(h[:16])
}


@@ -0,0 +1,332 @@
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
package proto
// AgentClientMessage (msg 118) oneof "message"
const (
ACM_RunRequest = 1 // AgentRunRequest
ACM_ExecClientMessage = 2 // ExecClientMessage
ACM_KvClientMessage = 3 // KvClientMessage
ACM_ConversationAction = 4 // ConversationAction
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
ACM_InteractionResponse = 6 // InteractionResponse
ACM_ClientHeartbeat = 7 // ClientHeartbeat
)
// AgentServerMessage (msg 119) oneof "message"
const (
ASM_InteractionUpdate = 1 // InteractionUpdate
ASM_ExecServerMessage = 2 // ExecServerMessage
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
ASM_KvServerMessage = 4 // KvServerMessage
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
ASM_InteractionQuery = 7 // InteractionQuery
)
// AgentRunRequest (msg 91)
const (
ARR_ConversationState = 1 // ConversationStateStructure
ARR_Action = 2 // ConversationAction
ARR_ModelDetails = 3 // ModelDetails
ARR_McpTools = 4 // McpTools
ARR_ConversationId = 5 // string (optional)
)
// ConversationStateStructure (msg 83)
const (
CSS_RootPromptMessagesJson = 1 // repeated bytes
CSS_TurnsOld = 2 // repeated bytes (deprecated)
CSS_Todos = 3 // repeated bytes
CSS_PendingToolCalls = 4 // repeated string
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
CSS_PreviousWorkspaceUris = 9 // repeated string
CSS_SelfSummaryCount = 17 // uint32
CSS_ReadPaths = 18 // repeated string
)
// ConversationAction (msg 54) oneof "action"
const (
CA_UserMessageAction = 1 // UserMessageAction
)
// UserMessageAction (msg 55)
const (
UMA_UserMessage = 1 // UserMessage
)
// UserMessage (msg 63)
const (
UM_Text = 1 // string
UM_MessageId = 2 // string
UM_SelectedContext = 3 // SelectedContext (optional)
)
// SelectedContext
const (
SC_SelectedImages = 1 // repeated SelectedImage
)
// SelectedImage
const (
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
SI_Uuid = 2 // string
SI_Path = 3 // string
SI_MimeType = 7 // string
SI_Data = 8 // bytes (oneof dataOrBlobId)
)
// ModelDetails (msg 88)
const (
MD_ModelId = 1 // string
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
MD_DisplayModelId = 3 // string
MD_DisplayName = 4 // string
)
// McpTools (msg 307)
const (
MT_McpTools = 1 // repeated McpToolDefinition
)
// McpToolDefinition (msg 306)
const (
MTD_Name = 1 // string
MTD_Description = 2 // string
MTD_InputSchema = 3 // bytes
MTD_ProviderIdentifier = 4 // string
MTD_ToolName = 5 // string
)
// ConversationTurnStructure (msg 70) oneof "turn"
const (
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
)
// AgentConversationTurnStructure (msg 72)
const (
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
)
// ConversationStep (msg 53) oneof "message"
const (
CS_AssistantMessage = 1 // AssistantMessage
)
// AssistantMessage
const (
AM_Text = 1 // string
)
// --- Server-side message fields ---
// InteractionUpdate oneof "message"
const (
IU_TextDelta = 1 // TextDeltaUpdate
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
)
// TextDeltaUpdate (msg 92)
const (
TDU_Text = 1 // string
)
// ThinkingDeltaUpdate (msg 97)
const (
TKD_Text = 1 // string
)
// KvServerMessage (msg 271)
const (
KSM_Id = 1 // uint32
KSM_GetBlobArgs = 2 // GetBlobArgs
KSM_SetBlobArgs = 3 // SetBlobArgs
)
// GetBlobArgs (msg 267)
const (
GBA_BlobId = 1 // bytes
)
// SetBlobArgs (msg 269)
const (
SBA_BlobId = 1 // bytes
SBA_BlobData = 2 // bytes
)
// KvClientMessage (msg 272)
const (
KCM_Id = 1 // uint32
KCM_GetBlobResult = 2 // GetBlobResult
KCM_SetBlobResult = 3 // SetBlobResult
)
// GetBlobResult (msg 268)
const (
GBR_BlobData = 1 // bytes (optional)
)
// ExecServerMessage
const (
ESM_Id = 1 // uint32
ESM_ExecId = 15 // string
// oneof message:
ESM_ShellArgs = 2 // ShellArgs
ESM_WriteArgs = 3 // WriteArgs
ESM_DeleteArgs = 4 // DeleteArgs
ESM_GrepArgs = 5 // GrepArgs
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
ESM_LsArgs = 8 // LsArgs
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
ESM_RequestContextArgs = 10 // RequestContextArgs
ESM_McpArgs = 11 // McpArgs
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
ESM_FetchArgs = 20 // FetchArgs
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
)
// ExecClientMessage
const (
ECM_Id = 1 // uint32
ECM_ExecId = 15 // string
// oneof message (mirrors server fields):
ECM_ShellResult = 2
ECM_WriteResult = 3
ECM_DeleteResult = 4
ECM_GrepResult = 5
ECM_ReadResult = 7
ECM_LsResult = 8
ECM_DiagnosticsResult = 9
ECM_RequestContextResult = 10
ECM_McpResult = 11
ECM_ShellStream = 14
ECM_BackgroundShellSpawnRes = 16
ECM_FetchResult = 20
ECM_WriteShellStdinResult = 23
)
// McpArgs
const (
MCA_Name = 1 // string
MCA_Args = 2 // map<string, bytes>
MCA_ToolCallId = 3 // string
MCA_ProviderIdentifier = 4 // string
MCA_ToolName = 5 // string
)
// RequestContextResult oneof "result"
const (
RCR_Success = 1 // RequestContextSuccess
RCR_Error = 2 // RequestContextError
)
// RequestContextSuccess (msg 337)
const (
RCS_RequestContext = 1 // RequestContext
)
// RequestContext
const (
RC_Rules = 2 // repeated CursorRule
RC_Tools = 7 // repeated McpToolDefinition
)
// McpResult oneof "result"
const (
MCR_Success = 1 // McpSuccess
MCR_Error = 2 // McpError
MCR_Rejected = 3 // McpRejected
)
// McpSuccess (msg 290)
const (
MCS_Content = 1 // repeated McpToolResultContentItem
MCS_IsError = 2 // bool
)
// McpToolResultContentItem oneof "content"
const (
MTRCI_Text = 1 // McpTextContent
)
// McpTextContent (msg 287)
const (
MTC_Text = 1 // string
)
// McpError (msg 291)
const (
MCE_Error = 1 // string
)
// --- Rejection messages ---
// ReadRejected: path=1, reason=2
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
// WriteRejected: path=1, reason=2
// DeleteRejected: path=1, reason=2
// LsRejected: path=1, reason=2
// GrepError: error=1
// FetchError: url=1, error=2
// WriteShellStdinError: error=1
// ReadResult oneof: success=1, error=2, rejected=3
// ShellResult oneof: success=1 (+ various), rejected=?
// The TS code uses specific result field numbers from the oneof:
const (
RR_Rejected = 3 // ReadResult.rejected
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
WR_Rejected = 5 // WriteResult.rejected
DR_Rejected = 3 // DeleteResult.rejected
LR_Rejected = 3 // LsResult.rejected
GR_Error = 2 // GrepResult.error
FR_Error = 2 // FetchResult.error
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
WSSR_Error = 2 // WriteShellStdinResult.error
)
// --- Rejection struct fields ---
const (
REJ_Path = 1
REJ_Reason = 2
SREJ_Command = 1
SREJ_WorkingDir = 2
SREJ_Reason = 3
SREJ_IsReadonly = 4
GERR_Error = 1
FERR_Url = 1
FERR_Error = 2
)
// ReadArgs
const (
RA_Path = 1 // string
)
// WriteArgs
const (
WA_Path = 1 // string
)
// DeleteArgs
const (
DA_Path = 1 // string
)
// LsArgs
const (
LA_Path = 1 // string
)
// ShellArgs
const (
SHA_Command = 1 // string
SHA_WorkingDirectory = 2 // string
)
// FetchArgs
const (
FA_Url = 1 // string
)


@@ -0,0 +1,313 @@
package proto
import (
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
const (
defaultInitialWindowSize = 65535 // HTTP/2 default
maxFramePayload = 16384 // HTTP/2 default max frame size
)
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
// Go's net/http client does not readily expose full-duplex HTTP/2 streams, so we drive the low-level framer directly.
type H2Stream struct {
framer *http2.Framer
conn net.Conn
streamID uint32
mu sync.Mutex
id string // unique identifier for debugging
frameNum int64 // sequential frame counter for debugging
dataCh chan []byte
doneCh chan struct{}
err error
// Send-side flow control
sendWindow int32 // available bytes we can send on this stream
connWindow int32 // available bytes on the connection level
windowCond *sync.Cond // signaled when window is updated
windowMu sync.Mutex // protects sendWindow, connWindow
}
// ID returns the unique identifier for this stream (for logging).
func (s *H2Stream) ID() string { return s.id }
// FrameNum returns the current frame number for debugging.
func (s *H2Stream) FrameNum() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.frameNum
}
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
NextProtos: []string{"h2"},
})
if err != nil {
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
}
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
tlsConn.Close()
return nil, fmt.Errorf("h2: server did not negotiate h2")
}
framer := http2.NewFramer(tlsConn, tlsConn)
// Client connection preface
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: preface write failed: %w", err)
}
// Send initial SETTINGS (tell server how much WE can receive)
if err := framer.WriteSettings(
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: settings write failed: %w", err)
}
// Connection-level window update (for receiving)
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: window update failed: %w", err)
}
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
// Track server's initial window size (how much WE can send)
serverInitialWindowSize := int32(defaultInitialWindowSize)
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
for i := 0; i < 10; i++ {
f, err := framer.ReadFrame()
if err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
}
switch sf := f.(type) {
case *http2.SettingsFrame:
if !sf.IsAck() {
sf.ForeachSetting(func(s http2.Setting) error {
if s.ID == http2.SettingInitialWindowSize {
serverInitialWindowSize = int32(s.Val)
log.Debugf("h2: server initial window size: %d", s.Val)
}
return nil
})
framer.WriteSettingsAck()
} else {
goto handshakeDone
}
case *http2.WindowUpdateFrame:
if sf.StreamID == 0 {
connWindowSize += int32(sf.Increment)
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
}
default:
// unexpected but continue
}
}
handshakeDone:
// Build HEADERS
streamID := uint32(1)
var hdrBuf []byte
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
if p, ok := headers[":path"]; ok {
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
}
for k, v := range headers {
if len(k) > 0 && k[0] == ':' {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
if err := framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: hdrBuf,
EndStream: false,
EndHeaders: true,
}); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: headers write failed: %w", err)
}
s := &H2Stream{
framer: framer,
conn: tlsConn,
streamID: streamID,
dataCh: make(chan []byte, 256),
doneCh: make(chan struct{}),
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
frameNum: 0,
sendWindow: serverInitialWindowSize,
connWindow: connWindowSize,
}
s.windowCond = sync.NewCond(&s.windowMu)
go s.readLoop()
return s, nil
}
// Write sends a DATA frame on the stream, respecting flow control.
func (s *H2Stream) Write(data []byte) error {
for len(data) > 0 {
chunk := data
if len(chunk) > maxFramePayload {
chunk = data[:maxFramePayload]
}
// Wait for flow control window
s.windowMu.Lock()
for s.sendWindow <= 0 || s.connWindow <= 0 {
s.windowCond.Wait()
}
// Limit chunk to available window
allowed := int(s.sendWindow)
if int(s.connWindow) < allowed {
allowed = int(s.connWindow)
}
if len(chunk) > allowed {
chunk = chunk[:allowed]
}
s.sendWindow -= int32(len(chunk))
s.connWindow -= int32(len(chunk))
s.windowMu.Unlock()
s.mu.Lock()
err := s.framer.WriteData(s.streamID, false, chunk)
s.mu.Unlock()
if err != nil {
return err
}
data = data[len(chunk):]
}
return nil
}
// Data returns the channel of received data chunks.
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
// Done returns a channel closed when the stream ends.
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
// Err returns the error (if any) that caused the stream to close.
// Returns nil for a clean shutdown (EOF / StreamEnded).
func (s *H2Stream) Err() error { return s.err }
// Close tears down the connection.
func (s *H2Stream) Close() {
s.conn.Close()
// Unblock any writers waiting on flow control
s.windowCond.Broadcast()
}
func (s *H2Stream) readLoop() {
defer close(s.doneCh)
defer close(s.dataCh)
for {
f, err := s.framer.ReadFrame()
if err != nil {
if err != io.EOF {
s.err = err
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
}
return
}
// Increment frame counter
s.mu.Lock()
s.frameNum++
s.mu.Unlock()
switch frame := f.(type) {
case *http2.DataFrame:
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
cp := make([]byte, len(frame.Data()))
copy(cp, frame.Data())
s.dataCh <- cp
// Flow control: send WINDOW_UPDATE for received data
s.mu.Lock()
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
s.mu.Unlock()
}
if frame.StreamEnded() {
return
}
case *http2.HeadersFrame:
if frame.StreamEnded() {
return
}
case *http2.RSTStreamFrame:
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
return
case *http2.GoAwayFrame:
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
return
case *http2.PingFrame:
if !frame.IsAck() {
s.mu.Lock()
s.framer.WritePing(true, frame.Data)
s.mu.Unlock()
}
case *http2.SettingsFrame:
if !frame.IsAck() {
// Check for window size changes
frame.ForeachSetting(func(setting http2.Setting) error {
if setting.ID == http2.SettingInitialWindowSize {
s.windowMu.Lock()
delta := int32(setting.Val) - s.sendWindow
s.sendWindow += delta
s.windowMu.Unlock()
s.windowCond.Broadcast()
}
return nil
})
s.mu.Lock()
s.framer.WriteSettingsAck()
s.mu.Unlock()
}
case *http2.WindowUpdateFrame:
// Update send-side flow control window
s.windowMu.Lock()
if frame.StreamID == 0 {
s.connWindow += int32(frame.Increment)
} else if frame.StreamID == s.streamID {
s.sendWindow += int32(frame.Increment)
}
s.windowMu.Unlock()
s.windowCond.Broadcast()
}
}
}
type sliceWriter struct{ buf *[]byte }
func (w *sliceWriter) Write(p []byte) (int, error) {
*w.buf = append(*w.buf, p...)
return len(p), nil
}

View File

@@ -24,6 +24,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
sdkAuth.NewCodeBuddyAuthenticator(),
sdkAuth.NewCursorAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,37 @@
package cmd
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
}
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
if err != nil {
log.Errorf("Cursor authentication failed: %v", err)
return
}
if savedPath != "" {
log.Infof("Authentication saved to %s", savedPath)
}
if record != nil && record.Label != "" {
log.Infof("Authenticated as %s", record.Label)
}
log.Info("Cursor authentication successful!")
}

View File

@@ -195,6 +195,9 @@ type RemoteManagement struct {
SecretKey string `yaml:"secret-key"`
// DisableControlPanel skips serving and syncing the bundled management UI when true.
DisableControlPanel bool `yaml:"disable-control-panel"`
// DisableAutoUpdatePanel disables the periodic background update of the management panel asset from GitHub.
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
PanelGitHubRepository string `yaml:"panel-github-repository"`
@@ -208,6 +211,10 @@ type QuotaExceeded struct {
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
}
// RoutingConfig configures how credentials are selected for requests.
@@ -254,8 +261,8 @@ type AmpCode struct {
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
@@ -377,6 +384,11 @@ type ClaudeKey struct {
// Cloak configures request cloaking for non-Claude-Code clients.
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
// Claude /v1/messages requests. It is disabled by default so upstream seed
// changes do not alter the proxy's legacy behavior.
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
}
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }

View File

@@ -4,6 +4,7 @@
package logging
import (
"bufio"
"bytes"
"compress/flate"
"compress/gzip"
@@ -41,15 +42,17 @@ type RequestLogger interface {
// - statusCode: The response status code
// - responseHeaders: The response headers
// - response: The raw response data
// - websocketTimeline: Optional downstream websocket event timeline
// - apiRequest: The API request data
// - apiResponse: The API response data
// - apiWebsocketTimeline: Optional upstream websocket event timeline
// - requestID: Optional request ID for log file naming
// - requestTimestamp: When the request was received
// - apiResponseTimestamp: When the API response was received
//
// Returns:
// - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
//
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
// This should be called when upstream communication happened over websocket.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
//
// Parameters:
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
//
// Returns:
// - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
}
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
if !l.enabled && !force {
return nil
}
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
requestHeaders,
body,
requestBodyPath,
websocketTimeline,
apiRequest,
apiResponse,
apiWebsocketTimeline,
apiResponseErrors,
statusCode,
responseHeaders,
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
requestHeaders map[string][]string,
requestBody []byte,
requestBodyPath string,
websocketTimeline []byte,
apiRequest []byte,
apiResponse []byte,
apiWebsocketTimeline []byte,
apiResponseErrors []*interfaces.ErrorMessage,
statusCode int,
responseHeaders map[string][]string,
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if requestTimestamp.IsZero() {
requestTimestamp = time.Now()
}
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
return errWrite
}
if isWebsocketTranscript {
// Intentionally omit the generic downstream HTTP response section for websocket
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
// and appending a one-off upgrade response snapshot would dilute that transcript.
return nil
}
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
}
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
body []byte,
bodyPath string,
timestamp time.Time,
downstreamTransport string,
upstreamTransport string,
includeBody bool,
) error {
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
return errWrite
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
return errWrite
}
if strings.TrimSpace(downstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
return errWrite
}
}
if strings.TrimSpace(upstreamTransport) != "" {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite
}
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
}
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
return errWrite
}
if !includeBody {
return nil
}
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
return errWrite
}
bodyTrailingNewlines := 1
if bodyPath != "" {
bodyFile, errOpen := os.Open(bodyPath)
if errOpen != nil {
return errOpen
}
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
tracker := &trailingNewlineTrackingWriter{writer: w}
written, errCopy := io.Copy(tracker, bodyFile)
if errCopy != nil {
_ = bodyFile.Close()
return errCopy
}
if written > 0 {
bodyTrailingNewlines = tracker.trailingNewlines
}
if errClose := bodyFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request body temp file")
}
} else if _, errWrite := w.Write(body); errWrite != nil {
return errWrite
} else if len(body) > 0 {
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
}
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
return errWrite
}
return nil
}
func countTrailingNewlinesBytes(payload []byte) int {
count := 0
for i := len(payload) - 1; i >= 0; i-- {
if payload[i] != '\n' {
break
}
count++
}
return count
}
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
missingNewlines := 3 - trailingNewlines
if missingNewlines <= 0 {
return nil
}
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
return errWrite
}
type trailingNewlineTrackingWriter struct {
writer io.Writer
trailingNewlines int
}
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
written, errWrite := t.writer.Write(payload)
if written > 0 {
writtenPayload := payload[:written]
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
if trailingNewlines == len(writtenPayload) {
t.trailingNewlines += trailingNewlines
} else {
t.trailingNewlines = trailingNewlines
}
}
return written, errWrite
}
func hasSectionPayload(payload []byte) bool {
return len(bytes.TrimSpace(payload)) > 0
}
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
if hasSectionPayload(websocketTimeline) {
return "websocket"
}
for key, values := range headers {
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
for _, value := range values {
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
return "websocket"
}
}
}
}
return "http"
}
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
hasWS := hasSectionPayload(apiWebsocketTimeline)
switch {
case hasHTTP && hasWS:
return "websocket+http"
case hasWS:
return "websocket"
case hasHTTP:
return "http"
default:
return ""
}
}
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
if len(payload) == 0 {
return nil
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
if !bytes.HasSuffix(payload, []byte("\n")) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
} else {
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
return errWrite
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
return errWrite
}
return nil
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
return errWrite
}
trailingNewlines := 1
if apiResponseErrors[i].Error != nil {
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
errText := apiResponseErrors[i].Error.Error()
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
return errWrite
}
if errText != "" {
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
}
}
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
return errWrite
}
}
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
var bufferedReader *bufio.Reader
if responseReader != nil {
bufferedReader = bufio.NewReader(responseReader)
}
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
if responseReader != nil {
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
if bufferedReader != nil {
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
return errCopy
}
}
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
return nil
}
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
if reader == nil {
return false
}
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
return true
}
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
return true
}
return false
}
// formatLogContent creates the complete log content for non-streaming requests.
//
// Parameters:
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - websocketTimeline: The downstream websocket event timeline
// - apiRequest: The API request data
// - apiResponse: The API response data
// - response: The raw response data
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
//
// Returns:
// - string: The formatted log content
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
var content strings.Builder
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
// Request info
content.WriteString(l.formatRequestInfo(url, method, headers, body))
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
if len(websocketTimeline) > 0 {
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
content.Write(websocketTimeline)
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
content.Write(websocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiWebsocketTimeline) > 0 {
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
content.Write(apiWebsocketTimeline)
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
content.Write(apiWebsocketTimeline)
content.WriteString("\n")
}
content.WriteString("\n")
}
if len(apiRequest) > 0 {
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
content.WriteString("\n")
}
if isWebsocketTranscript {
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
// timeline sections instead of a generic downstream HTTP response block.
return content.String()
}
// Response section
content.WriteString("=== RESPONSE ===\n")
content.WriteString(fmt.Sprintf("Status: %d\n", status))
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
//
// Returns:
// - string: The formatted request information
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
var content strings.Builder
content.WriteString("=== REQUEST INFO ===\n")
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
content.WriteString(fmt.Sprintf("URL: %s\n", url))
content.WriteString(fmt.Sprintf("Method: %s\n", method))
if strings.TrimSpace(downstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
}
if strings.TrimSpace(upstreamTransport) != "" {
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
}
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
content.WriteString("\n")
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
}
content.WriteString("\n")
if !includeBody {
return content.String()
}
content.WriteString("=== REQUEST BODY ===\n")
content.Write(body)
content.WriteString("\n\n")
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
// apiResponse stores the upstream API response data.
apiResponse []byte
// apiWebsocketTimeline stores the upstream websocket event timeline.
apiWebsocketTimeline []byte
// apiResponseTimestamp captures when the API response was received.
apiResponseTimestamp time.Time
}
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
return nil
}
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
if len(apiWebsocketTimeline) == 0 {
return nil
}
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
if !timestamp.IsZero() {
w.apiResponseTimestamp = timestamp
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
// Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
//
// Returns:
// - error: An error if closing fails, nil otherwise
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
}
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil
}
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
return nil
}
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
// Close is a no-op implementation that does nothing and always returns nil.

View File

@@ -31,6 +31,7 @@ const (
httpUserAgent = "CLIProxyAPI-management-updater"
managementSyncMinInterval = 30 * time.Second
updateCheckInterval = 3 * time.Hour
maxAssetDownloadSize = 50 << 20 // 50 MB safety limit for management asset downloads
)
// ManagementFileName exposes the control panel asset filename.
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
log.Debug("management asset auto-updater skipped: control panel disabled")
return
}
if cfg.RemoteManagement.DisableAutoUpdatePanel {
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
return
}
configPath, _ := schedulerConfigPath.Load().(string)
staticDir := StaticDir(configPath)
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
}
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
return nil, nil
}
if err = atomicWriteFile(localPath, data); err != nil {
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
return false
}
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
data, err := io.ReadAll(resp.Body)
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
if err != nil {
return nil, "", fmt.Errorf("read download body: %w", err)
}
if int64(len(data)) > maxAssetDownloadSize {
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
}
sum := sha256.Sum256(data)
return data, hex.EncodeToString(sum[:]), nil

View File

@@ -231,11 +231,25 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetAntigravityModels()
case "codebuddy":
return GetCodeBuddyModels()
case "cursor":
return GetCursorModels()
default:
return nil
}
}
// GetCursorModels returns the fallback Cursor model definitions.
func GetCursorModels() []*ModelInfo {
return []*ModelInfo{
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
@@ -260,6 +274,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetKiloModels(),
GetAmazonQModels(),
GetCodeBuddyModels(),
GetCursorModels(),
}
for _, models := range allModels {
for _, m := range models {
@@ -462,6 +477,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.4-mini",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.4 mini",
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
Object: "model",
@@ -556,6 +584,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-3-pro-preview",
@@ -567,6 +596,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-3.1-pro-preview",
@@ -576,8 +606,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
ContextLength: 173000,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-3-flash-preview",
@@ -587,8 +618,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot",
DisplayName: "Gemini 3 Flash (Preview)",
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
ContextLength: 1048576,
ContextLength: 173000,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "grok-code-fast-1",

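The hunks above pin several Gemini Copilot models to `SupportedEndpoints: []string{"/chat/completions"}`. A minimal sketch of how such a list can gate routing, using a pared-down stand-in for `ModelInfo` and assuming (not confirmed by this diff) that an empty list means "no restriction":

```go
package main

import "fmt"

// modelInfo is a pared-down stand-in for the registry's ModelInfo type.
type modelInfo struct {
	ID                 string
	SupportedEndpoints []string
}

// supportsEndpoint reports whether a model may be routed to endpoint.
// Treating an empty list as "no restriction" is an assumption for this
// sketch; the real registry may interpret absence differently.
func supportsEndpoint(m modelInfo, endpoint string) bool {
	if len(m.SupportedEndpoints) == 0 {
		return true
	}
	for _, e := range m.SupportedEndpoints {
		if e == endpoint {
			return true
		}
	}
	return false
}

func main() {
	gemini := modelInfo{ID: "gemini-2.5-pro", SupportedEndpoints: []string{"/chat/completions"}}
	fmt.Println(supportsEndpoint(gemini, "/chat/completions")) // true
	fmt.Println(supportsEndpoint(gemini, "/responses"))        // false
}
```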

@@ -0,0 +1,29 @@
package registry
import "testing"
func TestGitHubCopilotGeminiModelsAreChatOnly(t *testing.T) {
models := GetGitHubCopilotModels()
required := map[string]bool{
"gemini-2.5-pro": false,
"gemini-3-pro-preview": false,
"gemini-3.1-pro-preview": false,
"gemini-3-flash-preview": false,
}
for _, model := range models {
if _, ok := required[model.ID]; !ok {
continue
}
required[model.ID] = true
if len(model.SupportedEndpoints) != 1 || model.SupportedEndpoints[0] != "/chat/completions" {
t.Fatalf("model %q supported endpoints = %v, want [/chat/completions]", model.ID, model.SupportedEndpoints)
}
}
for modelID, found := range required {
if !found {
t.Fatalf("expected GitHub Copilot model %q in definitions", modelID)
}
}
}


@@ -14,7 +14,9 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -46,8 +48,16 @@ func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Man
// Identifier returns the executor identifier.
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio).
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
// PrepareRequest prepares the HTTP request for execution.
func (e *AIStudioExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
@@ -66,6 +76,9 @@ func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.A
return nil, fmt.Errorf("aistudio executor: missing auth")
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
return nil, fmt.Errorf("aistudio executor: request URL is empty")
}
@@ -115,8 +128,8 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false)
if err != nil {
@@ -130,6 +143,11 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -137,7 +155,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
@@ -151,17 +169,17 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
if len(wsResp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
helps.AppendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
}
if wsResp.Status < 200 || wsResp.Status >= 300 {
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
}
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
reporter.Publish(ctx, helps.ParseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
@@ -174,8 +192,8 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true)
if err != nil {
@@ -189,13 +207,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
@@ -208,24 +231,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
})
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
firstEvent, ok := <-wsStream
if !ok {
err = fmt.Errorf("wsrelay: stream closed before start")
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
metadataLogged := false
if firstEvent.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
metadataLogged = true
}
var body bytes.Buffer
if len(firstEvent.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
body.Write(firstEvent.Payload)
}
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
@@ -233,18 +256,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
for event := range wsStream {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
if body.Len() == 0 {
body.WriteString(event.Err.Error())
}
break
}
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
body.Write(event.Payload)
}
if event.Type == wsrelay.MessageTypeStreamEnd {
@@ -260,23 +283,23 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
metadataLogged := false
processEvent := func(event wsrelay.StreamEvent) bool {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
switch event.Type {
case wsrelay.MessageTypeStreamStart:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
filtered := FilterSSEUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
filtered := helps.FilterSSEUsageMetadata(event.Payload)
if detail, ok := helps.ParseGeminiStreamUsage(filtered); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
@@ -288,21 +311,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return false
case wsrelay.MessageTypeHTTPResp:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
return false
case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
@@ -345,7 +368,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
@@ -358,12 +381,12 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
})
resp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
if len(resp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
helps.AppendAPIResponseChunk(ctx, e.cfg, resp.Body)
}
if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
@@ -404,8 +427,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
return nil, translatedPayload{}, err
}
payload = fixGeminiImageAspectRatio(baseModel, payload)
requestedModel := payloadRequestedModel(opts, req.Model)
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")


@@ -24,6 +24,7 @@ import (
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -47,12 +48,41 @@ const (
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
antigravityCreditsRetryTTL = 5 * time.Hour
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
)
type antigravity429Category string
const (
antigravity429Unknown antigravity429Category = "unknown"
antigravity429RateLimited antigravity429Category = "rate_limited"
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
antigravityCreditsExhaustedByAuth sync.Map
antigravityPreferCreditsByModel sync.Map
antigravityQuotaExhaustedKeywords = []string{
"quota_exhausted",
"quota exhausted",
}
antigravityCreditsExhaustedKeywords = []string{
"google_one_ai",
"insufficient credit",
"insufficient credits",
"not enough credit",
"not enough credits",
"credit exhausted",
"credits exhausted",
"credit balance",
"minimumcreditamountforusage",
"minimum credit amount for usage",
"minimum credit",
"resource has been exhausted",
}
)
// AntigravityExecutor proxies requests to the antigravity upstream.
@@ -113,7 +143,7 @@ func initAntigravityTransport() {
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
antigravityTransportOnce.Do(initAntigravityTransport)
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
client := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
// If no transport is set, use the shared HTTP/1.1 transport.
if client.Transport == nil {
client.Transport = antigravityTransport
@@ -183,6 +213,231 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
return httpClient.Do(httpReq)
}
func injectEnabledCreditTypes(payload []byte) []byte {
if len(payload) == 0 {
return nil
}
if !gjson.ValidBytes(payload) {
return nil
}
updated, err := sjson.SetRawBytes(payload, "enabledCreditTypes", []byte(`["GOOGLE_ONE_AI"]`))
if err != nil {
return nil
}
return updated
}
func classifyAntigravity429(body []byte) antigravity429Category {
if len(body) == 0 {
return antigravity429Unknown
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityQuotaExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return antigravity429QuotaExhausted
}
}
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
return antigravity429Unknown
}
details := gjson.GetBytes(body, "error.details")
if !details.Exists() || !details.IsArray() {
return antigravity429Unknown
}
for _, detail := range details.Array() {
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
continue
}
reason := strings.TrimSpace(detail.Get("reason").String())
if strings.EqualFold(reason, "QUOTA_EXHAUSTED") {
return antigravity429QuotaExhausted
}
if strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED") {
return antigravity429RateLimited
}
}
return antigravity429Unknown
}
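The structured half of the classifier above walks `error.details` for a `google.rpc.ErrorInfo` entry and maps its `reason` to a category. The same logic can be sketched with only `encoding/json` (omitting the keyword fast-path that the real function checks first):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// classify429 mirrors the structured branch of classifyAntigravity429: a body
// counts as quota-exhausted when error.status is RESOURCE_EXHAUSTED and an
// ErrorInfo detail carries reason QUOTA_EXHAUSTED; RATE_LIMIT_EXCEEDED maps
// to rate_limited; anything else is unknown.
func classify429(body []byte) string {
	var parsed struct {
		Error struct {
			Status  string `json:"status"`
			Details []struct {
				Type   string `json:"@type"`
				Reason string `json:"reason"`
			} `json:"details"`
		} `json:"error"`
	}
	if err := json.Unmarshal(body, &parsed); err != nil {
		return "unknown"
	}
	if !strings.EqualFold(parsed.Error.Status, "RESOURCE_EXHAUSTED") {
		return "unknown"
	}
	for _, d := range parsed.Error.Details {
		if d.Type != "type.googleapis.com/google.rpc.ErrorInfo" {
			continue
		}
		switch strings.ToUpper(strings.TrimSpace(d.Reason)) {
		case "QUOTA_EXHAUSTED":
			return "quota_exhausted"
		case "RATE_LIMIT_EXCEEDED":
			return "rate_limited"
		}
	}
	return "unknown"
}

func main() {
	body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`)
	fmt.Println(classify429(body)) // quota_exhausted
}
```

Distinguishing the two reasons matters because only `QUOTA_EXHAUSTED` should trigger the credits fallback; a plain rate limit is transient and retrying with credits would burn them for nothing.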
func antigravityCreditsRetryEnabled(cfg *config.Config) bool {
return cfg != nil && cfg.QuotaExceeded.AntigravityCredits
}
func antigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) bool {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return false
}
value, ok := antigravityCreditsExhaustedByAuth.Load(auth.ID)
if !ok {
return false
}
until, ok := value.(time.Time)
if !ok || until.IsZero() {
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
return false
}
if !until.After(now) {
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
return false
}
return true
}
func markAntigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
}
antigravityCreditsExhaustedByAuth.Store(auth.ID, now.Add(antigravityCreditsRetryTTL))
}
func clearAntigravityCreditsExhausted(auth *cliproxyauth.Auth) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
}
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
}
func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string {
if auth == nil {
return ""
}
authID := strings.TrimSpace(auth.ID)
modelName = strings.TrimSpace(modelName)
if authID == "" || modelName == "" {
return ""
}
return authID + "|" + modelName
}
func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return false
}
value, ok := antigravityPreferCreditsByModel.Load(key)
if !ok {
return false
}
until, ok := value.(time.Time)
if !ok || until.IsZero() {
antigravityPreferCreditsByModel.Delete(key)
return false
}
if !until.After(now) {
antigravityPreferCreditsByModel.Delete(key)
return false
}
return true
}
func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return
}
until := now.Add(antigravityCreditsRetryTTL)
if retryAfter != nil && *retryAfter > 0 {
until = now.Add(*retryAfter)
}
antigravityPreferCreditsByModel.Store(key, until)
}
func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return
}
antigravityPreferCreditsByModel.Delete(key)
}
func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool {
if reqErr != nil || statusCode == 0 {
return false
}
if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout {
return false
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityCreditsExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return true
}
}
return false
}
func newAntigravityStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if statusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil {
err.retryAfter = retryAfter
}
}
return err
}
func (e *AntigravityExecutor) attemptCreditsFallback(
ctx context.Context,
auth *cliproxyauth.Auth,
httpClient *http.Client,
token string,
modelName string,
payload []byte,
stream bool,
alt string,
baseURL string,
originalBody []byte,
) (*http.Response, bool) {
if !antigravityCreditsRetryEnabled(e.cfg) {
return nil, false
}
if classifyAntigravity429(originalBody) != antigravity429QuotaExhausted {
return nil, false
}
now := time.Now()
if antigravityCreditsExhausted(auth, now) {
return nil, false
}
creditsPayload := injectEnabledCreditTypes(payload)
if len(creditsPayload) == 0 {
return nil, false
}
httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL)
if errReq != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errReq)
return nil, true
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, true
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
retryAfter, _ := parseRetryDelay(originalBody)
markAntigravityPreferCredits(auth, modelName, now, retryAfter)
clearAntigravityCreditsExhausted(auth)
return httpResp, true
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, true
}
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, modelName)
markAntigravityCreditsExhausted(auth, now)
}
return nil, true
}
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
@@ -203,8 +458,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
@@ -222,8 +477,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -237,7 +492,15 @@ attemptLoop:
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, false, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
@@ -245,7 +508,7 @@ attemptLoop:
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
@@ -260,20 +523,50 @@ attemptLoop:
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone())
creditsBody, errCreditsRead := io.ReadAll(creditsResp.Body)
if errClose := creditsResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close credits success response body error: %v", errClose)
}
if errCreditsRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead)
err = errCreditsRead
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody)
reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()}
reporter.EnsurePublished(ctx)
return resp, nil
}
}
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
@@ -295,33 +588,21 @@ attemptLoop:
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
return resp, err
}
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(lastStatus, lastBody)
case lastErr != nil:
err = lastErr
default:
@@ -345,8 +626,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
@@ -364,8 +645,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -379,7 +660,15 @@ attemptLoop:
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
@@ -387,7 +676,7 @@ attemptLoop:
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
@@ -401,14 +690,14 @@ attemptLoop:
err = errDo
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
@@ -427,7 +716,24 @@ attemptLoop:
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
httpResp = creditsResp
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
}
}
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
goto streamSuccessClaudeNonStream
}
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
@@ -449,16 +755,11 @@ attemptLoop:
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
return resp, err
}
streamSuccessClaudeNonStream:
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
@@ -471,29 +772,29 @@ attemptLoop:
scanner.Buffer(nil, streamScannerBuffer)
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
line = helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
payload := helps.JSONPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
reporter.Publish(ctx, detail)
}
out <- cliproxyexecutor.StreamChunk{Payload: payload}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}
}(httpResp)
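The scan loop above runs each SSE line through `helps.FilterSSEUsageMetadata` and `helps.JSONPayload` before forwarding it. As a rough illustration only (the project's actual helper may handle more cases), a minimal `data:`-field extractor could look like this; the `[DONE]` sentinel handling is an assumption borrowed from OpenAI-style streams:

```go
package main

import (
	"fmt"
	"strings"
)

// jsonPayload (illustrative sketch): given one SSE line, return only the
// data payload when the line is a "data:" field, and nil for comments,
// event names, blank keep-alives, and terminal sentinels.
func jsonPayload(line []byte) []byte {
	trimmed := strings.TrimSpace(string(line))
	if !strings.HasPrefix(trimmed, "data:") {
		return nil // event:, id:, retry:, comments, blank lines
	}
	payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
	if payload == "" || payload == "[DONE]" {
		return nil
	}
	return []byte(payload)
}

func main() {
	fmt.Println(string(jsonPayload([]byte(`data: {"text":"hi"}`))))
}
```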
@@ -509,24 +810,18 @@ attemptLoop:
}
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
reporter.Publish(ctx, helps.ParseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(lastStatus, lastBody)
case lastErr != nil:
err = lastErr
default:
@@ -748,8 +1043,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
@@ -767,8 +1062,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -782,14 +1077,22 @@ attemptLoop:
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return nil, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil, errDo
}
@@ -803,14 +1106,14 @@ attemptLoop:
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
@@ -829,7 +1132,24 @@ attemptLoop:
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
httpResp = creditsResp
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
}
}
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
goto streamSuccessExecuteStream
}
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
@@ -851,16 +1171,11 @@ attemptLoop:
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
return nil, err
}
streamSuccessExecuteStream:
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
@@ -874,19 +1189,19 @@ attemptLoop:
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
line = helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
payload := helps.JSONPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
@@ -899,11 +1214,11 @@ attemptLoop:
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}
}(httpResp)
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -911,13 +1226,7 @@ attemptLoop:
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(lastStatus, lastBody)
case lastErr != nil:
err = lastErr
default:
@@ -1011,8 +1320,13 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if host := resolveHost(base); host != "" {
httpReq.Host = host
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: requestURL.String(),
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -1026,7 +1340,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return cliproxyexecutor.Response{}, errDo
}
@@ -1040,16 +1354,16 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
return cliproxyexecutor.Response{}, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
@@ -1305,6 +1619,11 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
if host := resolveHost(base); host != "" {
httpReq.Host = host
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -1316,7 +1635,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
if e.cfg != nil && e.cfg.RequestLog {
payloadLog = []byte(payloadStr)
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: requestURL.String(),
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -1479,7 +1798,7 @@ func antigravityWait(ctx context.Context, wait time.Duration) error {
}
}
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
var antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
return []string{base}
}


@@ -0,0 +1,423 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
func resetAntigravityCreditsRetryState() {
antigravityCreditsExhaustedByAuth = sync.Map{}
antigravityPreferCreditsByModel = sync.Map{}
}
func TestClassifyAntigravity429(t *testing.T) {
t.Run("quota exhausted", func(t *testing.T) {
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
}
})
t.Run("structured rate limit", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
if got := classifyAntigravity429(body); got != antigravity429RateLimited {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited)
}
})
t.Run("structured quota exhausted", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "QUOTA_EXHAUSTED"}
]
}
}`)
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
}
})
t.Run("unknown", func(t *testing.T) {
body := []byte(`{"error":{"message":"too many requests"}}`)
if got := classifyAntigravity429(body); got != antigravity429Unknown {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429Unknown)
}
})
}
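The cases above exercise both the free-form `message` path and the structured `google.rpc.ErrorInfo` path. A self-contained sketch of that classification logic, using only the stdlib (the real `classifyAntigravity429` and its return constants may differ in detail), could be:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// classify429 (illustrative sketch): decide whether a 429 body signals
// an exhausted quota or plain rate limiting, preferring the
// machine-readable google.rpc.ErrorInfo reason over the message text.
func classify429(body []byte) string {
	var parsed struct {
		Error struct {
			Message string `json:"message"`
			Details []struct {
				Type   string `json:"@type"`
				Reason string `json:"reason"`
			} `json:"details"`
		} `json:"error"`
	}
	if err := json.Unmarshal(body, &parsed); err != nil {
		return "unknown"
	}
	for _, d := range parsed.Error.Details {
		if strings.HasSuffix(d.Type, "google.rpc.ErrorInfo") {
			switch d.Reason {
			case "QUOTA_EXHAUSTED":
				return "quota_exhausted"
			case "RATE_LIMIT_EXCEEDED":
				return "rate_limited"
			}
		}
	}
	// Fall back to the free-form message.
	if strings.Contains(parsed.Error.Message, "QUOTA_EXHAUSTED") {
		return "quota_exhausted"
	}
	return "unknown"
}

func main() {
	fmt.Println(classify429([]byte(`{"error":{"message":"QUOTA_EXHAUSTED"}}`)))
}
```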
func TestInjectEnabledCreditTypes(t *testing.T) {
body := []byte(`{"model":"gemini-2.5-flash","request":{}}`)
got := injectEnabledCreditTypes(body)
if got == nil {
t.Fatal("injectEnabledCreditTypes() returned nil")
}
if !strings.Contains(string(got), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("injectEnabledCreditTypes() = %s, want enabledCreditTypes", string(got))
}
if got := injectEnabledCreditTypes([]byte(`not json`)); got != nil {
t.Fatalf("injectEnabledCreditTypes() for invalid json = %s, want nil", string(got))
}
}
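The test above only asserts that `"enabledCreditTypes":["GOOGLE_ONE_AI"]` appears in the output and that invalid JSON yields nil. A minimal stdlib sketch consistent with those assertions follows; placing the field under `request` is an assumption, not something the diff confirms:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// injectCredits (illustrative sketch): add
// enabledCreditTypes = ["GOOGLE_ONE_AI"] to the request object of a JSON
// payload, returning nil when the payload cannot be parsed.
func injectCredits(body []byte) []byte {
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil // matches the "not json" -> nil contract in the test
	}
	req, _ := payload["request"].(map[string]any)
	if req == nil {
		req = map[string]any{}
	}
	req["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
	payload["request"] = req
	out, err := json.Marshal(payload)
	if err != nil {
		return nil
	}
	return out
}

func main() {
	fmt.Println(string(injectCredits([]byte(`{"model":"gemini-2.5-flash","request":{}}`))))
}
```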
func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) {
for _, body := range [][]byte{
[]byte(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`),
[]byte(`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`),
[]byte(`{"error":{"message":"Resource has been exhausted"}}`),
} {
if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, body, nil) {
t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body))
}
}
if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) {
t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false")
}
}
func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var (
mu sync.Mutex
requestBodies []string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
requestBodies = append(requestBodies, string(body))
reqNum := len(requestBodies)
mu.Unlock()
if reqNum == 1 {
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
return
}
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("second request body missing enabledCreditTypes: %s", string(body))
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-credits-ok",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(resp.Payload) == 0 {
t.Fatal("Execute() returned empty payload")
}
mu.Lock()
defer mu.Unlock()
if len(requestBodies) != 2 {
t.Fatalf("request count = %d, want 2", len(requestBodies))
}
}
func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var requestCount int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-credits-exhausted",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
markAntigravityCreditsExhausted(auth, time.Now())
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err == nil {
t.Fatal("Execute() error = nil, want 429")
}
sErr, ok := err.(statusErr)
if !ok {
t.Fatalf("Execute() error type = %T, want statusErr", err)
}
if got := sErr.StatusCode(); got != http.StatusTooManyRequests {
t.Fatalf("Execute() status code = %d, want %d", got, http.StatusTooManyRequests)
}
if requestCount != 1 {
t.Fatalf("request count = %d, want 1", requestCount)
}
}
func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var (
mu sync.Mutex
requestBodies []string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
requestBodies = append(requestBodies, string(body))
reqNum := len(requestBodies)
mu.Unlock()
switch reqNum {
case 1:
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`))
case 2, 3:
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("request %d body missing enabledCreditTypes: %s", reqNum, string(body))
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
default:
t.Fatalf("unexpected request count %d", reqNum)
}
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-prefer-credits",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
request := cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity}
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
t.Fatalf("first Execute() error = %v", err)
}
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
t.Fatalf("second Execute() error = %v", err)
}
mu.Lock()
defer mu.Unlock()
if len(requestBodies) != 3 {
t.Fatalf("request count = %d, want 3", len(requestBodies))
}
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0])
}
if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("fallback request missing credits: %s", requestBodies[1])
}
if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("preferred request missing credits: %s", requestBodies[2])
}
}
func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var (
mu sync.Mutex
firstCount int
secondCount int
)
firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
firstCount++
reqNum := firstCount
mu.Unlock()
switch reqNum {
case 1:
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`))
case 2:
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("credits retry missing enabledCreditTypes: %s", string(body))
}
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`))
default:
t.Fatalf("unexpected first server request count %d", reqNum)
}
}))
defer firstServer.Close()
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
secondCount++
mu.Unlock()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
}))
defer secondServer.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-baseurl-fallback",
Attributes: map[string]string{
"base_url": firstServer.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
originalOrder := antigravityBaseURLFallbackOrder
defer func() { antigravityBaseURLFallbackOrder = originalOrder }()
antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
return []string{firstServer.URL, secondServer.URL}
}
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(resp.Payload) == 0 {
t.Fatal("Execute() returned empty payload")
}
if firstCount != 2 {
t.Fatalf("first server request count = %d, want 2", firstCount)
}
if secondCount != 1 {
t.Fatalf("second server request count = %d, want 1", secondCount)
}
}
func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var requestBodies []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
requestBodies = append(requestBodies, string(body))
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false},
})
auth := &cliproxyauth.Auth{
ID: "auth-flag-disabled",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil)
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err == nil {
t.Fatal("Execute() error = nil, want 429")
}
if len(requestBodies) != 1 {
t.Fatalf("request count = %d, want 1", len(requestBodies))
}
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0])
}
}


@@ -6,7 +6,6 @@ import (
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
@@ -18,10 +17,13 @@ import (
"time"
"github.com/andybalholm/brotli"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -44,6 +46,10 @@ type ClaudeExecutor struct {
// Previously "proxy_" was used, but that prefix is a detectable fingerprint difference.
const claudeToolPrefix = ""
// Anthropic-compatible upstreams may reject or even crash when Claude models
// omit max_tokens. Prefer registered model metadata before using a fallback.
const defaultModelMaxTokens = 1024
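The comment and constant above motivate the `ensureModelMaxTokens` call added later in this diff. A minimal sketch of the fallback half of that behavior, assuming a flat Claude Messages payload (the real helper also consults registered model metadata before falling back):

```go
package main

import (
	"encoding/json"
	"fmt"
)

const defaultModelMaxTokens = 1024

// ensureMaxTokens (illustrative sketch): if the request body lacks
// max_tokens, inject the fallback so Anthropic-compatible upstreams that
// require the field do not reject the call. Payloads that already set a
// limit, or that fail to parse, pass through untouched.
func ensureMaxTokens(body []byte) []byte {
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return body
	}
	if _, ok := payload["max_tokens"]; ok {
		return body // caller already set a limit
	}
	payload["max_tokens"] = defaultModelMaxTokens
	out, err := json.Marshal(payload)
	if err != nil {
		return body
	}
	return out
}

func main() {
	fmt.Println(string(ensureMaxTokens([]byte(`{"model":"claude-sonnet-4"}`))))
}
```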
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
func (e *ClaudeExecutor) Identifier() string { return "claude" }
@@ -86,7 +92,7 @@ func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -101,8 +107,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
@@ -125,8 +131,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body = ensureModelMaxTokens(body, baseModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
@@ -153,6 +160,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
if experimentalCCHSigningEnabled(e.cfg, auth) {
bodyForUpstream = signAnthropicMessagesBody(bodyForUpstream)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
@@ -166,7 +176,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -178,33 +188,33 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
@@ -213,7 +223,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -226,19 +236,19 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}()
data, err := io.ReadAll(decodedBody)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if stream {
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
}
} else {
reporter.publish(ctx, parseClaudeUsage(data))
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
@@ -269,8 +279,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
originalPayloadSource := req.Payload
@@ -291,8 +301,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body = ensureModelMaxTokens(body, baseModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
@@ -316,6 +327,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
if experimentalCCHSigningEnabled(e.cfg, auth) {
bodyForUpstream = signAnthropicMessagesBody(bodyForUpstream)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
@@ -329,7 +343,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -341,33 +355,33 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -376,7 +390,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -397,9 +411,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
scanner.Buffer(nil, 52_428_800) // 50MB
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
@@ -411,8 +425,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
return
@@ -424,9 +438,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
@@ -446,8 +460,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -496,7 +510,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -508,32 +522,32 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
resp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -541,7 +555,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}
decodedBody, err := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -554,10 +568,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}()
data, err := io.ReadAll(decodedBody)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
@@ -793,13 +807,13 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header
}
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
var deviceProfile claudeDeviceProfile
stabilizeDeviceProfile := helps.ClaudeDeviceProfileStabilizationEnabled(cfg)
var deviceProfile helps.ClaudeDeviceProfile
if stabilizeDeviceProfile {
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
deviceProfile = helps.ResolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
}
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,structured-outputs-2025-12-15,fast-mode-2026-02-01,redact-thinking-2026-02-12,token-efficient-tools-2026-03-28"
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
baseBetas = val
if !strings.Contains(val, "oauth") {
@@ -837,13 +851,22 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
r.Header.Set("Anthropic-Beta", baseBetas)
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
// Only set browser access header for API key mode; real Claude Code CLI does not send it.
if useAPIKey {
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
}
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
// Session ID: stable per auth/apiKey, matches Claude Code's X-Claude-Code-Session-Id header.
misc.EnsureHeader(r.Header, ginHeaders, "X-Claude-Code-Session-Id", helps.CachedSessionID(apiKey))
// Per-request UUID, matches Claude Code's x-client-request-id for first-party API.
if isAnthropicBase {
misc.EnsureHeader(r.Header, ginHeaders, "x-client-request-id", uuid.New().String())
}
r.Header.Set("Connection", "keep-alive")
if stream {
r.Header.Set("Accept", "text/event-stream")
@@ -858,16 +881,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
// to the configured baseline while still allowing newer official
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
if stabilizeDeviceProfile {
helps.ApplyClaudeDeviceProfileHeaders(r, deviceProfile)
} else {
helps.ApplyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(r, attrs)
if stabilizeDeviceProfile {
applyClaudeDeviceProfileHeaders(r, deviceProfile)
} else {
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
}
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
// may override it with a user-configured value. Compressed SSE breaks the line
// scanner regardless of user preference, so this is non-negotiable for streams.
@@ -893,7 +916,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
func checkSystemInstructions(payload []byte) []byte {
return checkSystemInstructionsWithMode(payload, false)
return checkSystemInstructionsWithSigningMode(payload, false, false, "2.1.63", "", "")
}
func isClaudeOAuthToken(apiKey string) bool {
@@ -1037,7 +1060,7 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
if prefix == "" {
return line
}
payload := jsonPayload(line)
payload := helps.JSONPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
}
@@ -1088,6 +1111,38 @@ func getClientUserAgent(ctx context.Context) string {
return ""
}
// parseEntrypointFromUA extracts the entrypoint from a Claude Code User-Agent.
// Format: "claude-cli/x.y.z (external, cli)" → "cli"
// Format: "claude-cli/x.y.z (external, vscode)" → "vscode"
// Returns "cli" if parsing fails or UA is not Claude Code.
func parseEntrypointFromUA(userAgent string) string {
// Find content inside parentheses
start := strings.Index(userAgent, "(")
end := strings.LastIndex(userAgent, ")")
if start < 0 || end <= start {
return "cli"
}
inner := userAgent[start+1 : end]
// Split by comma, take the second part (entrypoint is at index 1, after USER_TYPE)
// Format: "(USER_TYPE, ENTRYPOINT[, extra...])"
parts := strings.Split(inner, ",")
if len(parts) >= 2 {
ep := strings.TrimSpace(parts[1])
if ep != "" {
return ep
}
}
return "cli"
}
// getWorkloadFromContext extracts workload identifier from the gin request headers.
func getWorkloadFromContext(ctx context.Context) string {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
return strings.TrimSpace(ginCtx.GetHeader("X-CPA-Claude-Workload"))
}
return ""
}
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
@@ -1115,43 +1170,14 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bo
return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
if cfg == nil || auth == nil {
return nil
}
apiKey, baseURL := claudeCreds(auth)
if apiKey == "" {
return nil
}
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
// Match by API key
if strings.EqualFold(cfgKey, apiKey) {
// If baseURL is specified, also check it
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
continue
}
return entry.Cloak
}
}
return nil
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
return helps.CachedUserID(apiKey)
}
return generateFakeUserID()
return helps.GenerateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
@@ -1161,38 +1187,84 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
if existingUserID == "" || !helps.IsValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
}
// generateBillingHeader creates the x-anthropic-billing-header text block that
// real Claude Code prepends to every system prompt array.
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
func generateBillingHeader(payload []byte) string {
// Generate a deterministic cch hash from the payload content (system + messages + tools).
// Real Claude Code uses a 5-char hex hash that varies per request.
h := sha256.Sum256(payload)
cch := hex.EncodeToString(h[:])[:5]
// fingerprintSalt is the salt used by Claude Code to compute the 3-char build fingerprint.
const fingerprintSalt = "59cf53e54c78"
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
buildBytes := make([]byte, 2)
_, _ = rand.Read(buildBytes)
buildHash := hex.EncodeToString(buildBytes)[:3]
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
// computeFingerprint computes the 3-char build fingerprint that Claude Code embeds in cc_version.
// Algorithm: SHA256(salt + messageText[4] + messageText[7] + messageText[20] + version)[:3]
func computeFingerprint(messageText, version string) string {
indices := [3]int{4, 7, 20}
runes := []rune(messageText)
var sb strings.Builder
for _, idx := range indices {
if idx < len(runes) {
sb.WriteRune(runes[idx])
} else {
sb.WriteRune('0')
}
}
input := fingerprintSalt + sb.String() + version
h := sha256.Sum256([]byte(input))
return hex.EncodeToString(h[:])[:3]
}
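For illustration, the fingerprint routine above can be exercised as a standalone program. The function body and salt are copied verbatim from the diff; the sample message text and version are placeholders:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// Same salt as in the diff above.
const fingerprintSalt = "59cf53e54c78"

// computeFingerprint mirrors the diff: take the runes at indices 4, 7
// and 20 of the message text (padding with '0' when the text is
// shorter), hash salt+runes+version with SHA-256, and keep the first
// 3 hex characters of the digest.
func computeFingerprint(messageText, version string) string {
	indices := [3]int{4, 7, 20}
	runes := []rune(messageText)
	var sb strings.Builder
	for _, idx := range indices {
		if idx < len(runes) {
			sb.WriteRune(runes[idx])
		} else {
			sb.WriteRune('0')
		}
	}
	h := sha256.Sum256([]byte(fingerprintSalt + sb.String() + version))
	return hex.EncodeToString(h[:])[:3]
}

func main() {
	fp := computeFingerprint("You are Claude Code, Anthropic's official CLI for Claude.", "2.1.63")
	fmt.Println(fp) // 3 lowercase hex chars, stable for a given text+version
}
```

Note that only three rune positions feed the hash, so any two prompts that agree at indices 4, 7 and 20 produce the same build fingerprint for a given version.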
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
// generateBillingHeader creates the x-anthropic-billing-header text block that
// real Claude Code prepends to every system prompt array.
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=<ep>; cch=<hash>; [cc_workload=<wl>;]
func generateBillingHeader(payload []byte, experimentalCCHSigning bool, version, messageText, entrypoint, workload string) string {
if entrypoint == "" {
entrypoint = "cli"
}
buildHash := computeFingerprint(messageText, version)
workloadPart := ""
if workload != "" {
workloadPart = fmt.Sprintf(" cc_workload=%s;", workload)
}
if experimentalCCHSigning {
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=00000;%s", version, buildHash, entrypoint, workloadPart)
}
// Generate a deterministic cch hash from the payload content (system + messages + tools).
h := sha256.Sum256(payload)
cch := hex.EncodeToString(h[:])[:5]
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=%s;%s", version, buildHash, entrypoint, cch, workloadPart)
}
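In the non-signing branch, `cch` is just the first 5 hex characters of a SHA-256 over the payload. A minimal sketch of the header assembly (the `version`, `buildHash`, and `entrypoint` values here are placeholders; the real code derives the build hash via `computeFingerprint` and the entrypoint from the client User-Agent):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// buildBillingHeader sketches the non-signing branch above: a
// deterministic 5-char content hash over the payload, formatted into
// the x-anthropic-billing-header text block.
func buildBillingHeader(payload []byte, version, buildHash, entrypoint string) string {
	h := sha256.Sum256(payload)
	cch := hex.EncodeToString(h[:])[:5]
	return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=%s;",
		version, buildHash, entrypoint, cch)
}

func main() {
	payload := []byte(`{"system":[],"messages":[],"tools":[]}`)
	fmt.Println(buildBillingHeader(payload, "2.1.63", "a43", "cli"))
}
```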
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
return checkSystemInstructionsWithSigningMode(payload, strictMode, false, "2.1.63", "", "")
}
// checkSystemInstructionsWithSigningMode injects Claude Code-style system blocks:
//
// system[0]: billing header (no cache_control)
// system[1]: agent identifier (no cache_control)
// system[2..]: user system messages (cache_control added when missing)
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
func checkSystemInstructionsWithSigningMode(payload []byte, strictMode bool, experimentalCCHSigning bool, version, entrypoint, workload string) []byte {
system := gjson.GetBytes(payload, "system")
billingText := generateBillingHeader(payload)
// Extract original message text for fingerprint computation (before billing injection).
// Use the first system text block's content as the fingerprint source.
messageText := ""
if system.IsArray() {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
messageText = part.Get("text").String()
return false
}
return true
})
} else if system.Type == gjson.String {
messageText = system.String()
}
billingText := generateBillingHeader(payload, experimentalCCHSigning, version, messageText, entrypoint, workload)
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
// No cache_control on the agent block. It is a cloaking artifact with zero cache
// value (the last system block is what actually triggers caching of all system content).
@@ -1247,51 +1319,44 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
clientUserAgent := getClientUserAgent(ctx)
useExperimentalCCHSigning := experimentalCCHSigningEnabled(cfg, auth)
// Get cloak config from ClaudeKey configuration
cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
// Determine cloak settings
var cloakMode string
var strictMode bool
var sensitiveWords []string
var cacheUserID bool
cloakMode := attrMode
strictMode := attrStrict
sensitiveWords := attrWords
cacheUserID := attrCache
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
if mode := strings.TrimSpace(cloakCfg.Mode); mode != "" {
cloakMode = mode
}
if cloakCfg.StrictMode {
strictMode = true
}
if len(cloakCfg.SensitiveWords) > 0 {
sensitiveWords = cloakCfg.SensitiveWords
}
if cloakCfg.CacheUserID != nil {
cacheUserID = *cloakCfg.CacheUserID
}
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
}
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
cacheUserID = attrCache
}
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
_, _, _, attrCache := getCloakConfigFromAuth(auth)
cacheUserID = attrCache
}
// Determine if cloaking should be applied
if !shouldCloak(cloakMode, clientUserAgent) {
if !helps.ShouldCloak(cloakMode, clientUserAgent) {
return payload
}
// Skip system instructions for claude-3-5-haiku models
if !strings.HasPrefix(model, "claude-3-5-haiku") {
payload = checkSystemInstructionsWithMode(payload, strictMode)
billingVersion := helps.DefaultClaudeVersion(cfg)
entrypoint := parseEntrypointFromUA(clientUserAgent)
workload := getWorkloadFromContext(ctx)
payload = checkSystemInstructionsWithSigningMode(payload, strictMode, useExperimentalCCHSigning, billingVersion, entrypoint, workload)
}
// Inject fake user ID
@@ -1299,8 +1364,8 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {
matcher := buildSensitiveWordMatcher(sensitiveWords)
payload = obfuscateSensitiveWords(payload, matcher)
matcher := helps.BuildSensitiveWordMatcher(sensitiveWords)
payload = helps.ObfuscateSensitiveWords(payload, matcher)
}
return payload
@@ -1310,7 +1375,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
// According to Anthropic's documentation, cache prefixes are created in order: tools -> system -> messages.
// This function adds cache_control to:
// 1. The LAST tool in the tools array (caches all tool definitions)
// 2. The LAST element in the system array (caches system prompt)
// 2. The LAST system prompt element
// 3. The SECOND-TO-LAST user turn (caches conversation history for multi-turn)
//
// Up to 4 cache breakpoints are allowed per request. Tools, System, and Messages are INDEPENDENT breakpoints.
@@ -1880,3 +1945,26 @@ func injectSystemCacheControl(payload []byte) []byte {
return payload
}
func ensureModelMaxTokens(body []byte, modelID string) []byte {
if len(body) == 0 || !gjson.ValidBytes(body) {
return body
}
if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() {
return body
}
for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) {
if strings.EqualFold(provider, "claude") {
maxTokens := defaultModelMaxTokens
if info := registry.GetGlobalRegistry().GetModelInfo(strings.TrimSpace(modelID), "claude"); info != nil && info.MaxCompletionTokens > 0 {
maxTokens = info.MaxCompletionTokens
}
body, _ = sjson.SetBytes(body, "max_tokens", maxTokens)
return body
}
}
return body
}


@@ -4,9 +4,11 @@ import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"sync"
"testing"
@@ -14,7 +16,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
xxHash64 "github.com/pierrec/xxHash/xxHash64"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -23,9 +28,7 @@ import (
)
func resetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
helps.ResetClaudeDeviceProfileCache()
}
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
@@ -98,7 +101,7 @@ func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
req := newClaudeHeaderTestRequest(t, incoming)
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
assertClaudeFingerprint(t, req.Header, "evil-client/9.9", "9.9.9", "v24.5.0", "Linux", "x64")
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
}
@@ -338,7 +341,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
var pauseOnce sync.Once
var releaseOnce sync.Once
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
helps.ClaudeDeviceProfileBeforeCandidateStore = func(candidate helps.ClaudeDeviceProfile) {
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
return
}
@@ -346,13 +349,13 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
<-releaseLow
}
t.Cleanup(func() {
claudeDeviceProfileBeforeCandidateStore = nil
helps.ClaudeDeviceProfileBeforeCandidateStore = nil
releaseOnce.Do(func() { close(releaseLow) })
})
lowResultCh := make(chan claudeDeviceProfile, 1)
lowResultCh := make(chan helps.ClaudeDeviceProfile, 1)
go func() {
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
lowResultCh <- helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
@@ -367,7 +370,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
t.Fatal("timed out waiting for lower candidate to pause before storing")
}
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
highResult := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
@@ -398,7 +401,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
}
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
cached := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"curl/8.7.1"},
}, cfg)
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
@@ -564,7 +567,7 @@ func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *tes
})
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
}
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
@@ -591,14 +594,14 @@ func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallbac
})
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
}
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
if claudeDeviceProfileStabilizationEnabled(nil) {
if helps.ClaudeDeviceProfileStabilizationEnabled(nil) {
t.Fatal("expected nil config to default to disabled stabilization")
}
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
if helps.ClaudeDeviceProfileStabilizationEnabled(&config.Config{}) {
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
}
}
@@ -796,8 +799,6 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
}
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
resetUserIDCache()
var userIDs []string
var requestModels []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -857,15 +858,13 @@ func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
if userIDs[0] != userIDs[1] {
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
}
if !isValidUserID(userIDs[0]) {
if !helps.IsValidUserID(userIDs[0]) {
t.Fatalf("user_id %q is not valid", userIDs[0])
}
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
}
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
resetUserIDCache()
var userIDs []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
@@ -903,7 +902,7 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
if userIDs[0] == userIDs[1] {
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
}
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
if !helps.IsValidUserID(userIDs[0]) || !helps.IsValidUserID(userIDs[1]) {
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
}
}
@@ -1183,6 +1182,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
}
}
func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-max-completion-tokens-client"
modelID := "test-claude-max-completion-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
MaxCompletionTokens: 4096,
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
t.Fatalf("max_tokens = %d, want %d", got, 4096)
}
}
func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-default-max-tokens-client"
modelID := "test-claude-default-max-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
}
}
func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-preserve-max-tokens-client"
modelID := "test-claude-preserve-max-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
MaxCompletionTokens: 4096,
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
t.Fatalf("max_tokens = %d, want %d", got, 2048)
}
}
func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, "test-claude-unregistered-model")
if gjson.GetBytes(out, "max_tokens").Exists() {
t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
}
}
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
// requests use Accept-Encoding: identity so the upstream cannot respond with a
// compressed SSE body that would silently break the line scanner.
@@ -1340,6 +1416,35 @@ func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes even when Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
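The gzip and zstd tests both rely on content sniffing when `Content-Encoding` is absent. The gzip (`1f 8b`) and zstd (`28 b5 2f fd`) prefixes below are the real on-the-wire magic numbers; `sniffContentEncoding` itself is only a sketch of the detection step inside `decodeResponseBody`, not the actual function.

```go
package main

import (
	"bytes"
	"fmt"
)

// Real magic-byte prefixes for the two compression formats the tests cover.
var (
	gzipMagic = []byte{0x1f, 0x8b}
	zstdMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}
)

// sniffContentEncoding inspects the first bytes of a body and reports the
// encoding to decompress with; an empty string means pass through as-is.
func sniffContentEncoding(head []byte) string {
	switch {
	case bytes.HasPrefix(head, gzipMagic):
		return "gzip"
	case bytes.HasPrefix(head, zstdMagic):
		return "zstd"
	default:
		return "" // plain text: return the reader untouched
	}
}

func main() {
	fmt.Println(sniffContentEncoding([]byte{0x28, 0xb5, 0x2f, 0xfd, 0x00})) // zstd
}
```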
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
// plain text untouched when Content-Encoding is absent and no magic bytes match.
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
@@ -1411,77 +1516,6 @@ func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T)
}
}
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
// path's enforced identity encoding.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
// Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
// the Content-Encoding header. This closes the gap left by PR #1771, which only
@@ -1565,6 +1599,45 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
}
}
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies that the
// streaming executor enforces Accept-Encoding: identity regardless of auth.Attributes override.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
@@ -1648,3 +1721,115 @@ func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}
func TestClaudeExecutor_ExperimentalCCHSigningDisabledByDefaultKeepsLegacyHeader(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected request body to be captured")
}
billingHeader := gjson.GetBytes(seenBody, "system.0.text").String()
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
t.Fatalf("system.0.text = %q, want billing header", billingHeader)
}
if strings.Contains(billingHeader, "cch=00000;") {
t.Fatalf("legacy mode should not forward cch placeholder, got %q", billingHeader)
}
}
func TestClaudeExecutor_ExperimentalCCHSigningOptInSignsFinalBody(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{
ClaudeKey: []config.ClaudeKey{{
APIKey: "key-123",
BaseURL: server.URL,
ExperimentalCCHSigning: true,
}},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
const messageText = "please keep literal cch=00000 in this message"
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"please keep literal cch=00000 in this message"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected request body to be captured")
}
if got := gjson.GetBytes(seenBody, "messages.0.content.0.text").String(); got != messageText {
t.Fatalf("message text = %q, want %q", got, messageText)
}
billingPattern := regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)([0-9a-f]{5})(;)`)
match := billingPattern.FindSubmatch(seenBody)
if match == nil {
t.Fatalf("expected signed billing header in body: %s", string(seenBody))
}
actualCCH := string(match[2])
unsignedBody := billingPattern.ReplaceAll(seenBody, []byte(`${1}00000${3}`))
wantCCH := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, 0x6E52736AC806831E)&0xFFFFF)
if actualCCH != wantCCH {
t.Fatalf("cch = %q, want %q\nbody: %s", actualCCH, wantCCH, string(seenBody))
}
}
func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmitted(t *testing.T) {
cfg := &config.Config{
ClaudeKey: []config.ClaudeKey{{
APIKey: "key-123",
Cloak: &config.CloakConfig{
StrictMode: true,
SensitiveWords: []string{"proxy"},
},
}},
}
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "key-123"}}
payload := []byte(`{"system":"proxy rules","messages":[{"role":"user","content":[{"type":"text","text":"proxy access"}]}]}`)
out := applyCloaking(context.Background(), cfg, auth, payload, "claude-3-5-sonnet-20241022", "key-123")
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("expected strict mode to keep only injected system blocks, got %d", len(blocks))
}
if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); !strings.Contains(got, "\u200B") {
t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got)
}
}


@@ -0,0 +1,81 @@
package executor
import (
"fmt"
"regexp"
"strings"
xxHash64 "github.com/pierrec/xxHash/xxHash64"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const claudeCCHSeed uint64 = 0x6E52736AC806831E
var claudeBillingHeaderCCHPattern = regexp.MustCompile(`\bcch=([0-9a-f]{5});`)
func signAnthropicMessagesBody(body []byte) []byte {
billingHeader := gjson.GetBytes(body, "system.0.text").String()
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
return body
}
if !claudeBillingHeaderCCHPattern.MatchString(billingHeader) {
return body
}
unsignedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(billingHeader, "cch=00000;")
unsignedBody, err := sjson.SetBytes(body, "system.0.text", unsignedBillingHeader)
if err != nil {
return body
}
cch := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, claudeCCHSeed)&0xFFFFF)
signedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(unsignedBillingHeader, "cch="+cch+";")
signedBody, err := sjson.SetBytes(unsignedBody, "system.0.text", signedBillingHeader)
if err != nil {
return unsignedBody
}
return signedBody
}
func resolveClaudeKeyConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.ClaudeKey {
if cfg == nil || auth == nil {
return nil
}
apiKey, baseURL := claudeCreds(auth)
if apiKey == "" {
return nil
}
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if !strings.EqualFold(cfgKey, apiKey) {
continue
}
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
continue
}
return entry
}
return nil
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
entry := resolveClaudeKeyConfig(cfg, auth)
if entry == nil {
return nil
}
return entry.Cloak
}
func experimentalCCHSigningEnabled(cfg *config.Config, auth *cliproxyauth.Auth) bool {
entry := resolveClaudeKeyConfig(cfg, auth)
return entry != nil && entry.ExperimentalCCHSigning
}
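The self-referential shape of `signAnthropicMessagesBody` — zero the `cch` field, hash the resulting body, write the digest back — can be sketched with only the standard library. Here `hash/fnv` is a loudly-labeled stand-in for the real seeded `pierrec/xxHash` xxHash64 checksum (which is not in the standard library); the 20-bit mask and `%05x` formatting match the real code, and `verify` shows the receiving side of the same scheme.

```go
package main

import (
	"fmt"
	"hash/fnv"
	"regexp"
)

var cchPattern = regexp.MustCompile(`\bcch=([0-9a-f]{5});`)

// checksum20 is an illustrative stand-in for xxHash64 with claudeCCHSeed:
// hash the bytes, keep the low 20 bits, render as five hex digits.
func checksum20(body []byte) string {
	h := fnv.New64a()
	_, _ = h.Write(body)
	return fmt.Sprintf("%05x", h.Sum64()&0xFFFFF)
}

// sign zeroes the cch field, hashes the resulting body, and writes the
// digest back, so the signature covers the final bytes minus itself.
func sign(body []byte) []byte {
	unsigned := cchPattern.ReplaceAll(body, []byte("cch=00000;"))
	cch := checksum20(unsigned)
	return cchPattern.ReplaceAll(unsigned, []byte("cch="+cch+";"))
}

// verify recomputes the checksum over the zeroed body and compares.
func verify(body []byte) bool {
	m := cchPattern.FindSubmatch(body)
	if m == nil {
		return false
	}
	unsigned := cchPattern.ReplaceAll(body, []byte("cch=00000;"))
	return string(m[1]) == checksum20(unsigned)
}

func main() {
	body := []byte(`{"t":"x-anthropic-billing-header: cch=00000; v=1"}`)
	fmt.Println(verify(sign(body))) // true
}
```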


@@ -13,6 +13,7 @@ import (
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -28,8 +29,8 @@ import (
)
const (
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexOriginator = "codex_cli_rs"
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
codexOriginator = "codex-tui"
)
var dataTag = []byte("data:")
@@ -73,7 +74,7 @@ func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -88,8 +89,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -106,16 +107,15 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -129,7 +129,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -140,10 +140,10 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -151,20 +151,20 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
@@ -177,8 +177,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
continue
}
if detail, ok := parseCodexUsage(line); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(line); ok {
reporter.Publish(ctx, detail)
}
var param any
@@ -198,8 +198,8 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai-response")
@@ -216,10 +216,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream")
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -233,7 +234,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -244,10 +245,10 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -255,22 +256,22 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
reporter.ensurePublished(ctx)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
reporter.EnsurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -288,8 +289,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -306,15 +307,14 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -328,7 +328,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -340,24 +340,24 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, readErr := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
return nil, readErr
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err
}
@@ -374,13 +374,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if bytes.HasPrefix(line, dataTag) {
data := bytes.TrimSpace(line[5:])
if gjson.GetBytes(data, "type").String() == "response.completed" {
if detail, ok := parseCodexUsage(data); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(data); ok {
reporter.Publish(ctx, detail)
}
}
}
@@ -391,8 +391,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -415,10 +415,9 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "stream", false)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
body = normalizeCodexInstructions(body)
enc, err := tokenizerForCodexModel(baseModel)
if err != nil {
@@ -597,18 +596,18 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
}
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
var cache codexCache
var cache helps.CodexCache
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
var ok bool
if cache, ok = getCodexCache(key); !ok {
cache = codexCache{
if cache, ok = helps.GetCodexCache(key); !ok {
cache = helps.CodexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
setCodexCache(key, cache)
helps.SetCodexCache(key, cache)
}
}
} else if from == "openai-response" {
@@ -617,7 +616,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
cache.ID = promptCacheKey.String()
}
} else if from == "openai" {
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
if apiKey := strings.TrimSpace(helps.APIKeyFromContext(ctx)); apiKey != "" {
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
}
}
@@ -630,7 +629,6 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
return nil, err
}
if cache.ID != "" {
httpReq.Header.Set("Conversation_id", cache.ID)
httpReq.Header.Set("Session_id", cache.ID)
}
return httpReq, nil
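The `openai` branch of `cacheHelper` derives a stable cache ID per API key with `uuid.NewSHA1(uuid.NameSpaceOID, ...)`. That is the RFC 4122 name-based version-5 construction, which can be reproduced with only the standard library; the sketch below uses the real OID namespace bytes but a made-up key, and stands in for the google/uuid call.

```go
package main

import (
	"crypto/sha1"
	"fmt"
)

// nameSpaceOID is the RFC 4122 OID namespace 6ba7b812-9dad-11d1-80b4-00c04fd430c8.
var nameSpaceOID = [16]byte{
	0x6b, 0xa7, 0xb8, 0x12, 0x9d, 0xad, 0x11, 0xd1,
	0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8,
}

// uuidV5 builds a name-based UUID: SHA-1 over namespace||name, then the
// version and variant bits are stamped into bytes 6 and 8.
func uuidV5(ns [16]byte, name string) string {
	h := sha1.New()
	h.Write(ns[:])
	h.Write([]byte(name))
	sum := h.Sum(nil)
	var u [16]byte
	copy(u[:], sum)
	u[6] = (u[6] & 0x0f) | 0x50 // version 5
	u[8] = (u[8] & 0x3f) | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:16])
}

func main() {
	// Hypothetical key: the same input always yields the same cache ID.
	a := uuidV5(nameSpaceOID, "cli-proxy-api:codex:prompt-cache:sk-demo")
	b := uuidV5(nameSpaceOID, "cli-proxy-api:codex:prompt-cache:sk-demo")
	fmt.Println(a == b) // true
}
```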
@@ -645,13 +643,19 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
ginHeaders = ginCtx.Request.Header
}
if ginHeaders.Get("X-Codex-Beta-Features") != "" {
r.Header.Set("X-Codex-Beta-Features", ginHeaders.Get("X-Codex-Beta-Features"))
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if strings.Contains(r.Header.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
}
if stream {
r.Header.Set("Accept", "text/event-stream")
} else {
@@ -685,13 +689,47 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
func newCodexStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
errCode := statusCode
if isCodexModelCapacityError(body) {
errCode = http.StatusTooManyRequests
}
err := statusErr{code: errCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
err.retryAfter = retryAfter
}
return err
}
func normalizeCodexInstructions(body []byte) []byte {
instructions := gjson.GetBytes(body, "instructions")
if !instructions.Exists() || instructions.Type == gjson.Null {
body, _ = sjson.SetBytes(body, "instructions", "")
}
return body
}
func isCodexModelCapacityError(errorBody []byte) bool {
if len(errorBody) == 0 {
return false
}
candidates := []string{
gjson.GetBytes(errorBody, "error.message").String(),
gjson.GetBytes(errorBody, "message").String(),
string(errorBody),
}
for _, candidate := range candidates {
lower := strings.ToLower(strings.TrimSpace(candidate))
if lower == "" {
continue
}
if strings.Contains(lower, "selected model is at capacity") ||
strings.Contains(lower, "model is at capacity. please try a different model") {
return true
}
}
return false
}
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
return nil


@@ -42,8 +42,8 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
if gotKey != expectedKey {
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
}
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != "" {
t.Fatalf("Conversation_id = %q, want empty", gotConversation)
}
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)


@@ -0,0 +1,79 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestCodexExecutorCompactAddsDefaultInstructions(t *testing.T) {
cases := []struct {
name string
payload string
}{
{
name: "missing instructions",
payload: `{"model":"gpt-5.4","input":"hello"}`,
},
{
name: "null instructions",
payload: `{"model":"gpt-5.4","instructions":null,"input":"hello"}`,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var gotPath string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(tc.payload),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Alt: "responses/compact",
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if gotPath != "/responses/compact" {
t.Fatalf("path = %q, want %q", gotPath, "/responses/compact")
}
if !gjson.GetBytes(gotBody, "instructions").Exists() {
t.Fatalf("expected instructions in compact request body, got %s", string(gotBody))
}
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
}
if gjson.GetBytes(gotBody, "instructions").String() != "" {
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
}
if string(resp.Payload) != `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}` {
t.Fatalf("payload = %s", string(resp.Payload))
}
})
}
}


@@ -0,0 +1,123 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestCodexExecutorExecuteNormalizesNullInstructions(t *testing.T) {
var gotPath string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if gotPath != "/responses" {
t.Fatalf("path = %q, want %q", gotPath, "/responses")
}
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
}
if gjson.GetBytes(gotBody, "instructions").String() != "" {
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
}
}
func TestCodexExecutorExecuteStreamNormalizesNullInstructions(t *testing.T) {
var gotPath string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for range result.Chunks {
}
if gotPath != "/responses" {
t.Fatalf("path = %q, want %q", gotPath, "/responses")
}
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
}
if gjson.GetBytes(gotBody, "instructions").String() != "" {
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
}
}
func TestCodexExecutorCountTokensTreatsNullInstructionsAsEmpty(t *testing.T) {
executor := NewCodexExecutor(&config.Config{})
nullResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
})
if err != nil {
t.Fatalf("CountTokens(null) error: %v", err)
}
emptyResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":"","input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
})
if err != nil {
t.Fatalf("CountTokens(empty) error: %v", err)
}
if string(nullResp.Payload) != string(emptyResp.Payload) {
t.Fatalf("token count payload mismatch:\nnull=%s\nempty=%s", string(nullResp.Payload), string(emptyResp.Payload))
}
}


@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
})
}
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
err := newCodexStatusErr(http.StatusBadRequest, body)
if got := err.StatusCode(); got != http.StatusTooManyRequests {
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
}
if err.RetryAfter() != nil {
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
}
}
func itoa(v int64) string {
return strconv.FormatInt(v, 10)
}


@@ -15,10 +15,12 @@ import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -44,10 +46,18 @@ const (
type CodexWebsocketsExecutor struct {
*CodexExecutor
sessMu sync.Mutex
store *codexWebsocketSessionStore
}
type codexWebsocketSessionStore struct {
mu sync.Mutex
sessions map[string]*codexWebsocketSession
}
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
sessions: make(map[string]*codexWebsocketSession),
}
type codexWebsocketSession struct {
sessionID string
@@ -71,7 +81,7 @@ type codexWebsocketSession struct {
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
return &CodexWebsocketsExecutor{
CodexExecutor: NewCodexExecutor(cfg),
sessions: make(map[string]*codexWebsocketSession),
store: globalCodexWebsocketSessionStore,
}
}
@@ -155,8 +165,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -173,8 +183,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -209,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
wsReqLog := helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -219,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
}
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if respHS != nil {
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
if respHS != nil {
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.Execute(ctx, auth, req, opts)
@@ -236,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if respHS != nil && respHS.StatusCode > 0 {
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
recordAPIResponseError(ctx, e.cfg, errDial)
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
return resp, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
defer func() {
@@ -268,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
// Retry once with a fresh websocket connection. This is mainly to handle
// upstream closing the socket between sequential requests within the same
// execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil {
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -282,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
AuthType: authType,
AuthValue: authValue,
})
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
conn = connRetry
wsReqBody = wsReqBodyRetry
} else {
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
recordAPIResponseError(ctx, e.cfg, errSendRetry)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
return resp, errSendRetry
}
} else {
recordAPIResponseError(ctx, e.cfg, errDialRetry)
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
return resp, errDialRetry
}
} else {
recordAPIResponseError(ctx, e.cfg, errSend)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
return resp, errSend
}
}
@@ -306,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
return resp, errRead
}
if msgType != websocket.TextMessage {
@@ -315,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
return resp, err
}
continue
@@ -325,21 +335,21 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if len(payload) == 0 {
continue
}
appendAPIResponseChunk(ctx, e.cfg, payload)
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
recordAPIResponseError(ctx, e.cfg, wsErr)
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
return resp, wsErr
}
payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" {
if detail, ok := parseCodexUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(payload); ok {
reporter.Publish(ctx, detail)
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param)
@@ -364,8 +374,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -376,8 +386,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
@@ -403,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
wsReqLog := helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -413,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
}
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header
if respHS != nil {
upstreamHeaders = respHS.Header.Clone()
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
if respHS != nil {
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
@@ -432,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if respHS != nil && respHS.StatusCode > 0 {
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
recordAPIResponseError(ctx, e.cfg, errDial)
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
if sess != nil {
sess.reqMu.Unlock()
}
return nil, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
@@ -451,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
recordAPIResponseError(ctx, e.cfg, errSend)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a new websocket connection for the same execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry != nil || connRetry == nil {
recordAPIResponseError(ctx, e.cfg, errDialRetry)
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
return nil, errDialRetry
}
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -475,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthType: authType,
AuthValue: authValue,
})
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
recordAPIResponseError(ctx, e.cfg, errSendRetry)
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
@@ -542,8 +554,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
terminateReason = "read_error"
terminateErr = errRead
recordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx)
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
reporter.PublishFailure(ctx)
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
return
}
@@ -552,8 +564,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
err = fmt.Errorf("codex websockets executor: unexpected binary message")
terminateReason = "unexpected_binary"
terminateErr = err
recordAPIResponseError(ctx, e.cfg, err)
reporter.publishFailure(ctx)
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
reporter.PublishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
@@ -567,13 +579,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if len(payload) == 0 {
continue
}
appendAPIResponseChunk(ctx, e.cfg, payload)
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
terminateReason = "upstream_error"
terminateErr = wsErr
recordAPIResponseError(ctx, e.cfg, wsErr)
reporter.publishFailure(ctx)
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
reporter.PublishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
@@ -584,8 +596,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" || eventType == "response.done" {
if detail, ok := parseCodexUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(payload); ok {
reporter.Publish(ctx, detail)
}
}
@@ -767,19 +779,19 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
return rawJSON, headers
}
var cache codexCache
var cache helps.CodexCache
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
if cached, ok := getCodexCache(key); ok {
if cached, ok := helps.GetCodexCache(key); ok {
cache = cached
} else {
cache = codexCache{
cache = helps.CodexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
setCodexCache(key, cache)
helps.SetCodexCache(key, cache)
}
}
} else if from == "openai-response" {
@@ -791,7 +803,6 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
headers.Set("Conversation_id", cache.ID)
headers.Set("Session_id", cache.ID)
}
return rawJSON, headers
@@ -806,11 +817,11 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
var ginHeaders http.Header
if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header.Clone()
}
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
@@ -826,8 +837,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
betaHeader = codexResponsesWebsocketBetaHeaderValue
}
headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
}
headers.Del("User-Agent")
isAPIKey := false
if auth != nil && auth.Attributes != nil {
@@ -1011,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
return line
}
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
upgradeInfo := info
upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL)
upgradeInfo.Method = http.MethodGet
upgradeInfo.Body = nil
upgradeInfo.Headers = info.Headers.Clone()
if upgradeInfo.Headers == nil {
upgradeInfo.Headers = make(http.Header)
}
if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" {
upgradeInfo.Headers.Set("Connection", "Upgrade")
}
if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" {
upgradeInfo.Headers.Set("Upgrade", "websocket")
}
return upgradeInfo
}
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
if resp == nil {
return
}
helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
}
func websocketHandshakeBody(resp *http.Response) []byte {
if resp == nil || resp.Body == nil {
return nil
@@ -1055,16 +1094,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
if sessionID == "" {
return nil
}
e.sessMu.Lock()
defer e.sessMu.Unlock()
if e.sessions == nil {
e.sessions = make(map[string]*codexWebsocketSession)
if e == nil {
return nil
}
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
defer store.mu.Unlock()
if store.sessions == nil {
store.sessions = make(map[string]*codexWebsocketSession)
}
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
return sess
}
sess := &codexWebsocketSession{sessionID: sessionID}
e.sessions[sessionID] = sess
store.sessions[sessionID] = sess
return sess
}
@@ -1210,14 +1256,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
return
}
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
e.closeAllExecutionSessions("executor_replaced")
// Executor replacement can happen during hot reload (config/credential changes).
// Do not force-close upstream websocket sessions here, otherwise in-flight
// downstream websocket requests get interrupted.
return
}
e.sessMu.Lock()
sess := e.sessions[sessionID]
delete(e.sessions, sessionID)
e.sessMu.Unlock()
store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
sess := store.sessions[sessionID]
delete(store.sessions, sessionID)
store.mu.Unlock()
e.closeExecutionSession(sess, "session_closed")
}
@@ -1227,15 +1279,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
return
}
e.sessMu.Lock()
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
for sessionID, sess := range e.sessions {
delete(e.sessions, sessionID)
store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
delete(store.sessions, sessionID)
if sess != nil {
sessions = append(sessions, sess)
}
}
e.sessMu.Unlock()
store.mu.Unlock()
for i := range sessions {
e.closeExecutionSession(sessions[i], reason)
@@ -1243,6 +1299,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
}
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
closeCodexWebsocketSession(sess, reason)
}
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
if sess == nil {
return
}
@@ -1283,6 +1343,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
}
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
// associated with the supplied auth ID.
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
reason = strings.TrimSpace(reason)
if reason == "" {
reason = "auth_removed"
}
store := globalCodexWebsocketSessionStore
if store == nil {
return
}
type sessionItem struct {
sessionID string
sess *codexWebsocketSession
}
store.mu.Lock()
items := make([]sessionItem, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
}
store.mu.Unlock()
matches := make([]sessionItem, 0)
for i := range items {
sess := items[i].sess
if sess == nil {
continue
}
sess.connMu.Lock()
sessAuthID := strings.TrimSpace(sess.authID)
sess.connMu.Unlock()
if sessAuthID == authID {
matches = append(matches, items[i])
}
}
if len(matches) == 0 {
return
}
toClose := make([]*codexWebsocketSession, 0, len(matches))
store.mu.Lock()
for i := range matches {
current, ok := store.sessions[matches[i].sessionID]
if !ok || current == nil || current != matches[i].sess {
continue
}
delete(store.sessions, matches[i].sessionID)
toClose = append(toClose, current)
}
store.mu.Unlock()
for i := range toClose {
closeCodexWebsocketSession(toClose[i], reason)
}
}
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
// 1. The downstream transport is websocket, and
// 2. The selected auth enables websockets.


@@ -0,0 +1,48 @@
package executor
import (
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
sessionID := "test-session-store-survives-replace"
globalCodexWebsocketSessionStore.mu.Lock()
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
globalCodexWebsocketSessionStore.mu.Unlock()
exec1 := NewCodexWebsocketsExecutor(nil)
sess1 := exec1.getOrCreateSession(sessionID)
if sess1 == nil {
t.Fatalf("expected session to be created")
}
exec2 := NewCodexWebsocketsExecutor(nil)
sess2 := exec2.getOrCreateSession(sessionID)
if sess2 == nil {
t.Fatalf("expected session to be available across executors")
}
if sess1 != sess2 {
t.Fatalf("expected the same session instance across executors")
}
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
globalCodexWebsocketSessionStore.mu.Lock()
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
globalCodexWebsocketSessionStore.mu.Unlock()
if !stillPresent {
t.Fatalf("expected session to remain after executor replacement close marker")
}
exec2.CloseExecutionSession(sessionID)
globalCodexWebsocketSessionStore.mu.Lock()
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
globalCodexWebsocketSessionStore.mu.Unlock()
if presentAfterClose {
t.Fatalf("expected session to be removed after explicit close")
}
}


@@ -38,8 +38,8 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
@@ -97,8 +97,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
@@ -129,8 +129,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
if gotVal := got.Get("User-Agent"); gotVal != "" {
t.Fatalf("User-Agent = %s, want empty", gotVal)
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
@@ -155,8 +155,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
@@ -177,8 +177,8 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)


@@ -0,0 +1,129 @@
package executor
import (
"context"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
}
func parseOpenAIUsage(data []byte) usage.Detail {
return helps.ParseOpenAIUsage(data)
}
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return helps.ParseOpenAIStreamUsage(line)
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
return helps.ParseOpenAIUsage(data)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
return helps.ParseOpenAIStreamUsage(line)
}
func getTokenizer(model string) (tokenizer.Codec, error) {
return helps.TokenizerForModel(model)
}
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return helps.CountOpenAIChatTokens(enc, payload)
}
func countClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return helps.CountClaudeChatTokens(enc, payload)
}
func buildOpenAIUsageJSON(count int64) []byte {
return helps.BuildOpenAIUsageJSON(count)
}
type upstreamRequestLog = helps.UpstreamRequestLog
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
helps.RecordAPIRequest(ctx, cfg, info)
}
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
helps.RecordAPIResponseMetadata(ctx, cfg, status, headers)
}
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
helps.RecordAPIResponseError(ctx, cfg, err)
}
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
helps.AppendAPIResponseChunk(ctx, cfg, chunk)
}
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
return helps.PayloadRequestedModel(opts, fallback)
}
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
return helps.ApplyPayloadConfigWithRoot(cfg, model, protocol, root, payload, original, requestedModel)
}
func summarizeErrorBody(contentType string, body []byte) string {
return helps.SummarizeErrorBody(contentType, body)
}
func apiKeyFromContext(ctx context.Context) string {
return helps.APIKeyFromContext(ctx)
}
func tokenizerForModel(model string) (tokenizer.Codec, error) {
return helps.TokenizerForModel(model)
}
func collectOpenAIContent(content gjson.Result, segments *[]string) {
helps.CollectOpenAIContent(content, segments)
}
type usageReporter struct {
reporter *helps.UsageReporter
}
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
return &usageReporter{reporter: helps.NewUsageReporter(ctx, provider, model, auth)}
}
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
if r == nil || r.reporter == nil {
return
}
r.reporter.Publish(ctx, detail)
}
func (r *usageReporter) publishFailure(ctx context.Context) {
if r == nil || r.reporter == nil {
return
}
r.reporter.PublishFailure(ctx)
}
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
if r == nil || r.reporter == nil {
return
}
r.reporter.TrackFailure(ctx, errPtr)
}
func (r *usageReporter) ensurePublished(ctx context.Context) {
if r == nil || r.reporter == nil {
return
}
r.reporter.EnsurePublished(ctx)
}
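Every method on the `usageReporter` wrapper above guards against both a nil wrapper and a nil inner reporter, so call sites can publish unconditionally. A minimal sketch of that nil-receiver pattern, using a hypothetical `Reporter` in place of `helps.UsageReporter`:

```go
package main

import "fmt"

// Reporter is a hypothetical stand-in for helps.UsageReporter.
type Reporter struct{ published int }

func (r *Reporter) Publish() { r.published++ }

// usageReporter mirrors the wrapper above: calling publish on a nil
// wrapper or one with a nil inner reporter is a safe no-op, because
// a Go method with a pointer receiver may run on a nil receiver as
// long as it never dereferences it.
type usageReporter struct{ reporter *Reporter }

func (u *usageReporter) publish() {
	if u == nil || u.reporter == nil {
		return
	}
	u.reporter.Publish()
}

func main() {
	var missing *usageReporter
	missing.publish() // safe: nil receiver is a no-op

	u := &usageReporter{reporter: &Reporter{}}
	u.publish()
	fmt.Println(u.reporter.published) // 1
}
```

This is why the executor code in the surrounding hunks can call `reporter.publish(...)` without first checking whether a reporter was constructed.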

File diff suppressed because it is too large


@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -81,6 +82,11 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
}
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(req, "unknown")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
@@ -112,8 +118,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
@@ -132,8 +138,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
action := "generateContent"
if req.Metadata != nil {
@@ -190,7 +196,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
@@ -204,7 +211,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
err = errDo
return resp, err
}
@@ -213,15 +220,15 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
reporter.publish(ctx, parseGeminiCLIUsage(data))
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -230,7 +237,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 {
if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -245,7 +252,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody)
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
}
if lastStatus == 0 {
lastStatus = 429
@@ -266,8 +273,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
return nil, err
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
@@ -286,8 +293,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
projectID := resolveGeminiProjectID(auth)
@@ -335,7 +342,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "text/event-stream")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
@@ -349,25 +357,25 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 {
if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -394,9 +402,9 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiCLIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseGeminiCLIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
@@ -411,8 +419,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
return
@@ -420,13 +428,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
data, errRead := io.ReadAll(resp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errRead}
return
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
@@ -443,7 +451,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody)
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
}
if lastStatus == 0 {
lastStatus = 429
@@ -516,7 +524,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, baseModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
@@ -530,17 +539,19 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
resp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
data, errRead := io.ReadAll(resp.Body)
_ = resp.Body.Close()
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if errClose := resp.Body.Close(); errClose != nil {
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
@@ -611,7 +622,7 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
}
ctxToken := ctx
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
}
@@ -707,7 +718,7 @@ func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
}
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
}
func cloneMap(in map[string]any) map[string]any {


@@ -13,6 +13,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -85,7 +86,7 @@ func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -110,8 +111,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
@@ -130,8 +131,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := "generateContent"
@@ -165,7 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -177,10 +178,10 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -188,21 +189,21 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
log.Errorf("gemini executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -218,8 +219,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -237,8 +238,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth)
@@ -268,7 +269,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -280,17 +281,17 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose)
}
@@ -310,14 +311,14 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
filtered := FilterSSEUsageMetadata(line)
payload := jsonPayload(filtered)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
filtered := helps.FilterSSEUsageMetadata(line)
payload := helps.JSONPayload(filtered)
if len(payload) == 0 {
continue
}
if detail, ok := parseGeminiStreamUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseGeminiStreamUsage(payload); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
@@ -329,8 +330,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -381,7 +382,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -393,23 +394,27 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
resp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
defer func() { _ = resp.Body.Close() }()
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
data, err := io.ReadAll(resp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, helps.SummarizeErrorBody(resp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
}


@@ -16,7 +16,9 @@ import (
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -227,7 +229,7 @@ func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -301,8 +303,8 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
var body []byte
@@ -332,8 +334,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
}
@@ -362,6 +364,11 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
return resp, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -369,7 +376,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -381,10 +388,10 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
@@ -392,21 +399,21 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
// For Imagen models, convert response to Gemini format before translation
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
@@ -427,8 +434,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -447,8 +454,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, false)
@@ -477,6 +484,11 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -484,7 +496,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -496,10 +508,10 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
@@ -507,21 +519,21 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -532,8 +544,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -552,8 +564,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
@@ -581,6 +593,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
return nil, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -588,7 +605,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -600,17 +617,17 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
@@ -630,9 +647,9 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
@@ -644,8 +661,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -656,8 +673,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -676,8 +693,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
@@ -705,6 +722,11 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -712,7 +734,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -724,17 +746,17 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
@@ -754,9 +776,9 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
@@ -768,8 +790,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -812,6 +834,11 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -819,7 +846,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -831,10 +858,10 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
@@ -842,19 +869,19 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
@@ -896,6 +923,11 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -903,7 +935,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -915,10 +947,10 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
@@ -926,19 +958,19 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
@@ -1012,7 +1044,7 @@ func vertexBaseURL(location string) string {
}
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
// Use cloud-platform scope for Vertex AI.


@@ -76,6 +76,9 @@ func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses")
}
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
t.Fatal("expected responses-only registry model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
@@ -83,15 +86,25 @@ func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *test
reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client"
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
}})
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5.4-mini",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
})
defer reg.UnregisterClient(clientID)
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {


@@ -30,12 +30,20 @@ const (
gitLabChatEndpoint = "/api/v4/chat/completions"
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
gitLabContext1MBeta = "context-1m-2025-08-07"
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
)
type GitLabExecutor struct {
cfg *config.Config
}
type gitLabCatalogModel struct {
ID string
DisplayName string
Provider string
}
type gitLabPrompt struct {
Instruction string
FileName string
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
Finished bool
}
var gitLabAgenticCatalog = []gitLabCatalogModel{
{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
}
var gitLabModelAliases = map[string]string{
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
}
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
return &GitLabExecutor{cfg: cfg}
}
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
auth *cliproxyauth.Auth,
req cliproxyexecutor.Request,
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
nativeReq := req
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
}
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
nativeReq := req
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
}
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
return NewClaudeExecutor(e.cfg), nativeAuth
}
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
return NewCodexExecutor(e.cfg), nativeAuth
}
return nil, nil
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
if auth != nil {
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
}
for key, value := range gitLabGatewayHeaders(auth) {
for key, value := range gitLabGatewayHeaders(auth, "") {
if key == "" || value == "" {
continue
}
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
}
}
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string {
if auth == nil || auth.Metadata == nil {
return nil
}
raw, ok := auth.Metadata["duo_gateway_headers"]
if !ok {
return nil
}
func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
out := make(map[string]string)
switch typed := raw.(type) {
case map[string]string:
for key, value := range typed {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key != "" && value != "" {
out[key] = value
if auth != nil && auth.Metadata != nil {
raw, ok := auth.Metadata["duo_gateway_headers"]
if ok {
switch typed := raw.(type) {
case map[string]string:
for key, value := range typed {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key != "" && value != "" {
out[key] = value
}
}
case map[string]any:
for key, value := range typed {
key = strings.TrimSpace(key)
if key == "" {
continue
}
strValue := strings.TrimSpace(fmt.Sprint(value))
if strValue != "" {
out[key] = strValue
}
}
}
}
case map[string]any:
for key, value := range typed {
key = strings.TrimSpace(key)
if key == "" {
continue
}
strValue := strings.TrimSpace(fmt.Sprint(value))
if strValue != "" {
out[key] = strValue
}
}
if _, ok := out["User-Agent"]; !ok {
out["User-Agent"] = gitLabNativeUserAgent
}
if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
if _, ok := out["anthropic-beta"]; !ok {
out["anthropic-beta"] = gitLabContext1MBeta
}
}
if len(out) == 0 {
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
return promptTokens, int64(completionCount)
}
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
if !gitLabUsesAnthropicGateway(auth) {
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
if !gitLabUsesAnthropicGateway(auth, requestedModel) {
return nil, false
}
baseURL := gitLabAnthropicGatewayBaseURL(auth)
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
}
nativeAuth.Attributes["api_key"] = token
nativeAuth.Attributes["base_url"] = baseURL
for key, value := range gitLabGatewayHeaders(auth) {
nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
if key == "" || value == "" {
continue
}
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
return nativeAuth, true
}
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
if !gitLabUsesOpenAIGateway(auth) {
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
if !gitLabUsesOpenAIGateway(auth, requestedModel) {
return nil, false
}
baseURL := gitLabOpenAIGatewayBaseURL(auth)
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
}
nativeAuth.Attributes["api_key"] = token
nativeAuth.Attributes["base_url"] = baseURL
for key, value := range gitLabGatewayHeaders(auth) {
for key, value := range gitLabGatewayHeaders(auth, "openai") {
if key == "" || value == "" {
continue
}
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
return nativeAuth, true
}
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool {
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
if auth == nil || auth.Metadata == nil {
return false
}
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
if provider == "" {
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
provider = inferGitLabProviderFromModel(modelName)
}
provider := gitLabGatewayProvider(auth, requestedModel)
return provider == "anthropic" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
}
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool {
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
if auth == nil || auth.Metadata == nil {
return false
}
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
if provider == "" {
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
provider = inferGitLabProviderFromModel(modelName)
}
provider := gitLabGatewayProvider(auth, requestedModel)
return provider == "openai" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
}
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
modelName := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
if provider := inferGitLabProviderFromModel(modelName); provider != "" {
return provider
}
if auth == nil || auth.Metadata == nil {
return ""
}
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
if provider == "" {
provider = inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
}
return provider
}
func inferGitLabProviderFromModel(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
switch {
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
return mapped
}
return requested
}
if auth != nil && auth.Metadata != nil {
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
}
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
models := make([]*registry.ModelInfo, 0, 4)
seen := make(map[string]struct{}, 4)
models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
addModel := func(id, displayName, provider string) {
id = strings.TrimSpace(id)
if id == "" {
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
}
addModel("gitlab-duo", "GitLab Duo", "gitlab")
for _, model := range gitLabAgenticCatalog {
addModel(model.ID, model.DisplayName, model.Provider)
}
for alias, upstream := range gitLabModelAliases {
target := strings.TrimSpace(upstream)
displayName := "GitLab Duo Alias"
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
if provider != "" {
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
}
addModel(alias, displayName, provider)
}
if auth == nil {
return models
}


@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
}
}
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotBetaHeader = r.Header.Get("anthropic-beta")
gotUserAgent = r.Header.Get("User-Agent")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "duo-chat-gpt-5-codex",
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotBetaHeader != gitLabContext1MBeta {
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
}
if gotUserAgent != gitLabNativeUserAgent {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
}
if gotModel != "duo-chat-gpt-5-codex" {
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
}
}
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
ID: "gitlab-auth.json",
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"oauth_client_id": "client-id",
"oauth_client_secret": "client-secret",
"auth_method": "oauth",
"oauth_expires_at": "2000-01-01T00:00:00Z",
"base_url": srv.URL,
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"oauth_client_id": "client-id",
"auth_method": "oauth",
"oauth_expires_at": "2000-01-01T00:00:00Z",
},
}
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
}
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
var gotPath string
var gotPath, gotBetaHeader, gotUserAgent string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotBetaHeader = r.Header.Get("Anthropic-Beta")
gotUserAgent = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: message_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
}
if gotUserAgent != gitLabNativeUserAgent {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
}
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
}


@@ -1,11 +1,11 @@
package executor
package helps
import (
"sync"
"time"
)
type codexCache struct {
type CodexCache struct {
ID string
Expire time.Time
}
@@ -13,7 +13,7 @@ type codexCache struct {
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
// Protected by codexCacheMu. Entries expire after 1 hour.
var (
codexCacheMap = make(map[string]codexCache)
codexCacheMap = make(map[string]CodexCache)
codexCacheMu sync.RWMutex
)
@@ -50,20 +50,20 @@ func purgeExpiredCodexCache() {
}
}
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func getCodexCache(key string) (codexCache, bool) {
// GetCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func GetCodexCache(key string) (CodexCache, bool) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.RLock()
cache, ok := codexCacheMap[key]
codexCacheMu.RUnlock()
if !ok || cache.Expire.Before(time.Now()) {
return codexCache{}, false
return CodexCache{}, false
}
return cache, true
}
// setCodexCache stores a cache entry.
func setCodexCache(key string, cache codexCache) {
// SetCodexCache stores a cache entry.
func SetCodexCache(key string, cache CodexCache) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.Lock()
codexCacheMap[key] = cache


@@ -1,4 +1,4 @@
package executor
package helps
import (
"crypto/sha256"
@@ -32,7 +32,7 @@ var (
claudeDeviceProfileCacheMu sync.RWMutex
claudeDeviceProfileCacheCleanupOnce sync.Once
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
ClaudeDeviceProfileBeforeCandidateStore func(ClaudeDeviceProfile)
)
type claudeCLIVersion struct {
@@ -63,29 +63,43 @@ func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
}
}
type claudeDeviceProfile struct {
type ClaudeDeviceProfile struct {
UserAgent string
PackageVersion string
RuntimeVersion string
OS string
Arch string
Version claudeCLIVersion
HasVersion bool
version claudeCLIVersion
hasVersion bool
}
type claudeDeviceProfileCacheEntry struct {
profile claudeDeviceProfile
profile ClaudeDeviceProfile
expire time.Time
}
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
func ClaudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
return false
}
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
}
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
func ResetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
}
func MapStainlessOS() string {
return mapStainlessOS()
}
func MapStainlessArch() string {
return mapStainlessArch()
}
func defaultClaudeDeviceProfile(cfg *config.Config) ClaudeDeviceProfile {
hdrDefault := func(cfgVal, fallback string) string {
if strings.TrimSpace(cfgVal) != "" {
return strings.TrimSpace(cfgVal)
@@ -98,7 +112,7 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
hd = cfg.ClaudeHeaderDefaults
}
profile := claudeDeviceProfile{
profile := ClaudeDeviceProfile{
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
@@ -106,8 +120,8 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
}
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
profile.Version = version
profile.HasVersion = true
profile.version = version
profile.hasVersion = true
}
return profile
}
@@ -162,17 +176,17 @@ func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
}
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.HasVersion {
func shouldUpgradeClaudeDeviceProfile(candidate, current ClaudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.hasVersion {
return false
}
if current.UserAgent == "" || !current.HasVersion {
if current.UserAgent == "" || !current.hasVersion {
return true
}
return candidate.Version.Compare(current.Version) > 0
return candidate.version.Compare(current.version) > 0
}
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
func pinClaudeDeviceProfilePlatform(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
profile.OS = baseline.OS
profile.Arch = baseline.Arch
return profile
@@ -180,38 +194,38 @@ func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claud
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
// baseline platform and enforces the baseline software fingerprint as a floor.
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
func normalizeClaudeDeviceProfile(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
if profile.UserAgent == "" || !profile.hasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
profile.UserAgent = baseline.UserAgent
profile.PackageVersion = baseline.PackageVersion
profile.RuntimeVersion = baseline.RuntimeVersion
profile.Version = baseline.Version
profile.HasVersion = baseline.HasVersion
profile.version = baseline.version
profile.hasVersion = baseline.hasVersion
}
return profile
}
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (ClaudeDeviceProfile, bool) {
if headers == nil {
return claudeDeviceProfile{}, false
return ClaudeDeviceProfile{}, false
}
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
version, ok := parseClaudeCLIVersion(userAgent)
if !ok {
return claudeDeviceProfile{}, false
return ClaudeDeviceProfile{}, false
}
baseline := defaultClaudeDeviceProfile(cfg)
profile := claudeDeviceProfile{
profile := ClaudeDeviceProfile{
UserAgent: userAgent,
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
Version: version,
HasVersion: true,
version: version,
hasVersion: true,
}
return profile, true
}
@@ -263,7 +277,7 @@ func purgeExpiredClaudeDeviceProfiles() {
claudeDeviceProfileCacheMu.Unlock()
}
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
func ResolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) ClaudeDeviceProfile {
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
@@ -283,8 +297,8 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
claudeDeviceProfileCacheMu.RUnlock()
if hasCandidate {
if claudeDeviceProfileBeforeCandidateStore != nil {
claudeDeviceProfileBeforeCandidateStore(candidate)
if ClaudeDeviceProfileBeforeCandidateStore != nil {
ClaudeDeviceProfileBeforeCandidateStore(candidate)
}
claudeDeviceProfileCacheMu.Lock()
@@ -324,7 +338,7 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
return baseline
}
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfile) {
if r == nil {
return
}
@@ -344,7 +358,17 @@ func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfil
r.Header.Set("X-Stainless-Arch", profile.Arch)
}
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
// DefaultClaudeVersion returns the version string (e.g. "2.1.63") from the
// current baseline device profile. It extracts the version from the User-Agent.
func DefaultClaudeVersion(cfg *config.Config) string {
profile := defaultClaudeDeviceProfile(cfg)
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
return strconv.Itoa(version.major) + "." + strconv.Itoa(version.minor) + "." + strconv.Itoa(version.patch)
}
return "2.1.63"
}
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
if r == nil {
return
}


@@ -1,4 +1,4 @@
package executor
package helps
import (
"regexp"
@@ -18,9 +18,9 @@ type SensitiveWordMatcher struct {
regex *regexp.Regexp
}
// buildSensitiveWordMatcher compiles a regex from the word list.
// BuildSensitiveWordMatcher compiles a regex from the word list.
// Words are sorted by length (longest first) for proper matching.
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
func BuildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
if len(words) == 0 {
return nil
}
@@ -81,9 +81,9 @@ func (m *SensitiveWordMatcher) obfuscateText(text string) string {
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
}
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
// ObfuscateSensitiveWords processes the payload and obfuscates sensitive words
// in system blocks and message content.
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
func ObfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
if matcher == nil || matcher.regex == nil {
return payload
}


@@ -1,4 +1,4 @@
package executor
package helps
import (
"crypto/rand"
@@ -28,9 +28,17 @@ func isValidUserID(userID string) bool {
return userIDPattern.MatchString(userID)
}
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
func GenerateFakeUserID() string {
return generateFakeUserID()
}
func IsValidUserID(userID string) bool {
return isValidUserID(userID)
}
// ShouldCloak determines if request should be cloaked based on config and client User-Agent.
// Returns true if cloaking should be applied.
func shouldCloak(cloakMode string, userAgent string) bool {
func ShouldCloak(cloakMode string, userAgent string) bool {
switch strings.ToLower(cloakMode) {
case "always":
return true


@@ -1,4 +1,4 @@
package executor
package helps
import (
"bytes"
@@ -6,6 +6,7 @@ import (
"fmt"
"html"
"net/http"
"net/url"
"sort"
"strings"
"time"
@@ -19,13 +20,14 @@ import (
)
const (
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
apiRequestKey = "API_REQUEST"
apiResponseKey = "API_RESPONSE"
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
apiRequestKey = "API_REQUEST"
apiResponseKey = "API_RESPONSE"
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
)
// upstreamRequestLog captures the outbound upstream request details for logging.
type upstreamRequestLog struct {
// UpstreamRequestLog captures the outbound upstream request details for logging.
type UpstreamRequestLog struct {
URL string
Method string
Headers http.Header
@@ -46,11 +48,12 @@ type upstreamAttempt struct {
headersWritten bool
bodyStarted bool
bodyHasContent bool
prevWasSSEEvent bool
errorWritten bool
}
// recordAPIRequest stores the upstream request metadata in Gin context for request logging.
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
// RecordAPIRequest stores the upstream request metadata in Gin context for request logging.
func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog {
return
}
@@ -96,8 +99,8 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
updateAggregatedRequest(ginCtx, attempts)
}
// recordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog {
return
}
@@ -122,8 +125,8 @@ func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status i
updateAggregatedResponse(ginCtx, attempts)
}
// recordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
// RecordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
func RecordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
if cfg == nil || !cfg.RequestLog || err == nil {
return
}
@@ -147,8 +150,8 @@ func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error)
updateAggregatedResponse(ginCtx, attempts)
}
// appendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
// AppendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
@@ -173,15 +176,157 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
attempt.response.WriteString("Body:\n")
attempt.bodyStarted = true
}
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
if attempt.bodyHasContent {
attempt.response.WriteString("\n\n")
separator := "\n\n"
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
separator = "\n"
}
attempt.response.WriteString(separator)
}
attempt.response.WriteString(string(data))
attempt.bodyHasContent = true
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
updateAggregatedResponse(ginCtx, attempts)
}
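The separator rule in the hunk above only affects request-log rendering, but its effect is easy to demonstrate in isolation. The sketch below uses a hypothetical standalone helper (`sseSeparator` is not a name from the codebase) that reproduces the decision: body chunks are normally joined with a blank line, except that a `data:` chunk directly following an `event:` chunk gets a single newline, so the event/data pair renders as one SSE frame in the log.

```go
package main

import (
	"bytes"
	"fmt"
	"strings"
)

// sseSeparator is a hypothetical standalone version of the rule above:
// chunks are joined with a blank line, except that a "data:" chunk
// immediately following an "event:" chunk is joined with a single
// newline so the pair stays one SSE frame.
func sseSeparator(prevWasEvent bool, chunk []byte) string {
	if prevWasEvent && bytes.HasPrefix(chunk, []byte("data:")) {
		return "\n"
	}
	return "\n\n"
}

func main() {
	chunks := []string{
		"event: message_start",
		`data: {"type":"message_start"}`,
		`data: {"type":"content_block_delta"}`,
	}
	var b strings.Builder
	prevWasEvent := false
	for i, c := range chunks {
		if i > 0 {
			b.WriteString(sseSeparator(prevWasEvent, []byte(c)))
		}
		b.WriteString(c)
		prevWasEvent = strings.HasPrefix(c, "event:")
	}
	fmt.Println(b.String())
}
```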
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.request\n")
if info.URL != "" {
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
}
if auth := formatAuthInfo(info); auth != "" {
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, info.Headers)
builder.WriteString("\nBody:\n")
if len(info.Body) > 0 {
builder.Write(info.Body)
} else {
builder.WriteString("<empty>")
}
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.handshake\n")
if status > 0 {
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
}
builder.WriteString("Headers:\n")
writeHeaders(builder, headers)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
RecordAPIRequest(ctx, cfg, info)
RecordAPIResponseMetadata(ctx, cfg, status, headers)
AppendAPIResponseChunk(ctx, cfg, body)
}
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
func WebsocketUpgradeRequestURL(rawURL string) string {
trimmedURL := strings.TrimSpace(rawURL)
if trimmedURL == "" {
return ""
}
parsed, err := url.Parse(trimmedURL)
if err != nil {
return trimmedURL
}
switch strings.ToLower(parsed.Scheme) {
case "ws":
parsed.Scheme = "http"
case "wss":
parsed.Scheme = "https"
}
return parsed.String()
}
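The scheme mapping in `WebsocketUpgradeRequestURL` can be reproduced with only the standard library. The sketch below (`wsToHTTP` is a hypothetical name for this standalone copy) shows the same behavior: `ws` maps to `http`, `wss` maps to `https`, and empty or unparseable input passes through.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// wsToHTTP reproduces WebsocketUpgradeRequestURL: convert a websocket
// URL back to its HTTP handshake URL, passing through anything that is
// blank or fails to parse.
func wsToHTTP(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	parsed, err := url.Parse(trimmed)
	if err != nil {
		return trimmed
	}
	switch strings.ToLower(parsed.Scheme) {
	case "ws":
		parsed.Scheme = "http"
	case "wss":
		parsed.Scheme = "https"
	}
	return parsed.String()
}

func main() {
	fmt.Println(wsToHTTP("wss://gateway.example.com/v1/stream?attempt=1"))
	fmt.Println(wsToHTTP("https://already.http.example.com/"))
}
```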
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
data := bytes.TrimSpace(payload)
if len(data) == 0 {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.response\n")
builder.Write(data)
builder.WriteString("\n")
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
if cfg == nil || !cfg.RequestLog || err == nil {
return
}
ginCtx := ginContextFrom(ctx)
if ginCtx == nil {
return
}
markAPIResponseTimestamp(ginCtx)
builder := &strings.Builder{}
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
builder.WriteString("Event: api.websocket.error\n")
if trimmed := strings.TrimSpace(stage); trimmed != "" {
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
}
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
}
func ginContextFrom(ctx context.Context) *gin.Context {
ginCtx, _ := ctx.Value("gin").(*gin.Context)
return ginCtx
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
ginCtx.Set(apiResponseKey, []byte(builder.String()))
}
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
if ginCtx == nil {
return
}
data := bytes.TrimSpace(chunk)
if len(data) == 0 {
return
}
if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
combined := make([]byte, 0, len(existingBytes)+len(data)+2)
combined = append(combined, existingBytes...)
if !bytes.HasSuffix(existingBytes, []byte("\n")) {
combined = append(combined, '\n')
}
combined = append(combined, '\n')
combined = append(combined, data...)
ginCtx.Set(apiWebsocketTimelineKey, combined)
return
}
}
ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data))
}
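The joining rule in `appendAPIWebsocketTimeline` can also be exercised on its own. The sketch below is a hypothetical standalone copy (`joinTimeline` is not a name from the codebase): timeline entries are separated by exactly one blank line, first appending a trailing newline to the previous entry if it lacked one.

```go
package main

import (
	"bytes"
	"fmt"
)

// joinTimeline mirrors appendAPIWebsocketTimeline's joining rule:
// trim the new entry, keep the existing buffer if the entry is empty,
// and otherwise separate entries with exactly one blank line.
func joinTimeline(existing, next []byte) []byte {
	next = bytes.TrimSpace(next)
	if len(next) == 0 {
		return existing
	}
	if len(existing) == 0 {
		return bytes.Clone(next)
	}
	combined := make([]byte, 0, len(existing)+len(next)+2)
	combined = append(combined, existing...)
	if !bytes.HasSuffix(existing, []byte("\n")) {
		combined = append(combined, '\n')
	}
	combined = append(combined, '\n')
	combined = append(combined, next...)
	return combined
}

func main() {
	t := joinTimeline(nil, []byte("Event: api.websocket.request"))
	t = joinTimeline(t, []byte("Event: api.websocket.handshake"))
	fmt.Printf("%q\n", t)
}
```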
func markAPIResponseTimestamp(ginCtx *gin.Context) {
if ginCtx == nil {
return
}
if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists {
return
}
ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
}
func writeHeaders(builder *strings.Builder, headers http.Header) {
if builder == nil {
return
@@ -285,7 +464,7 @@ func writeHeaders(builder *strings.Builder, headers http.Header) {
}
}
func formatAuthInfo(info upstreamRequestLog) string {
func formatAuthInfo(info UpstreamRequestLog) string {
var parts []string
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
@@ -321,7 +500,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
return strings.Join(parts, ", ")
}
func summarizeErrorBody(contentType string, body []byte) string {
func SummarizeErrorBody(contentType string, body []byte) string {
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
if !isHTML {
trimmed := bytes.TrimSpace(bytes.ToLower(body))
@@ -379,7 +558,7 @@ func extractJSONErrorMessage(body []byte) string {
// logWithRequestID returns a logrus Entry with request_id field populated from context.
// If no request ID is found in context, it returns the standard logger.
func logWithRequestID(ctx context.Context) *log.Entry {
func LogWithRequestID(ctx context.Context) *log.Entry {
if ctx == nil {
return log.NewEntry(log.StandardLogger())
}


@@ -1,4 +1,4 @@
package executor
package helps
import (
"encoding/json"
@@ -11,12 +11,12 @@ import (
"github.com/tidwall/sjson"
)
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided. requestedModel carries the client-visible
// model name before alias resolution so payload rules can target aliases precisely.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
if cfg == nil || len(payload) == 0 {
return payload
}
@@ -244,7 +244,7 @@ func payloadRawValue(value any) ([]byte, bool) {
}
}
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
fallback = strings.TrimSpace(fallback)
if len(opts.Metadata) == 0 {
return fallback


@@ -1,4 +1,4 @@
package executor
package helps
import (
"context"
@@ -19,7 +19,7 @@ var (
httpClientCacheMutex sync.RWMutex
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured
@@ -34,7 +34,7 @@ var (
//
// Returns:
// - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
// Priority 1: Use auth.ProxyURL if configured
var proxyURL string
if auth != nil {
@@ -46,23 +46,18 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
// Build cache key from proxy URL (empty string for no proxy)
cacheKey := proxyURL
// Check cache first
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[cacheKey]; ok {
httpClientCacheMutex.RUnlock()
// Return a wrapper with the requested timeout but shared transport
if timeout > 0 {
return &http.Client{
Transport: cachedClient.Transport,
Timeout: timeout,
// If we have a proxy URL configured, try cache first to reuse TCP/TLS connections.
if proxyURL != "" {
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[proxyURL]; ok {
httpClientCacheMutex.RUnlock()
if timeout > 0 {
return &http.Client{Transport: cachedClient.Transport, Timeout: timeout}
}
return cachedClient
}
return cachedClient
httpClientCacheMutex.RUnlock()
}
httpClientCacheMutex.RUnlock()
// Create new client
httpClient := &http.Client{}
@@ -77,7 +72,7 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = transport
// Cache the client
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCache[proxyURL] = httpClient
httpClientCacheMutex.Unlock()
return httpClient
}
@@ -90,13 +85,6 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt
}
// Cache the client for no-proxy case
if proxyURL == "" {
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
}
return httpClient
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"context"
@@ -13,7 +13,7 @@ import (
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
client := NewProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},


@@ -0,0 +1,92 @@
package helps
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
"github.com/google/uuid"
)
type sessionIDCacheEntry struct {
value string
expire time.Time
}
var (
sessionIDCache = make(map[string]sessionIDCacheEntry)
sessionIDCacheMu sync.RWMutex
sessionIDCacheCleanupOnce sync.Once
)
const (
sessionIDTTL = time.Hour
sessionIDCacheCleanupPeriod = 15 * time.Minute
)
func startSessionIDCacheCleanup() {
go func() {
ticker := time.NewTicker(sessionIDCacheCleanupPeriod)
defer ticker.Stop()
for range ticker.C {
purgeExpiredSessionIDs()
}
}()
}
func purgeExpiredSessionIDs() {
now := time.Now()
sessionIDCacheMu.Lock()
for key, entry := range sessionIDCache {
if !entry.expire.After(now) {
delete(sessionIDCache, key)
}
}
sessionIDCacheMu.Unlock()
}
func sessionIDCacheKey(apiKey string) string {
sum := sha256.Sum256([]byte(apiKey))
return hex.EncodeToString(sum[:])
}
// CachedSessionID returns a stable session UUID per apiKey, refreshing the TTL on each access.
func CachedSessionID(apiKey string) string {
if apiKey == "" {
return uuid.New().String()
}
sessionIDCacheCleanupOnce.Do(startSessionIDCacheCleanup)
key := sessionIDCacheKey(apiKey)
now := time.Now()
sessionIDCacheMu.RLock()
entry, ok := sessionIDCache[key]
valid := ok && entry.value != "" && entry.expire.After(now)
sessionIDCacheMu.RUnlock()
if valid {
sessionIDCacheMu.Lock()
entry = sessionIDCache[key]
if entry.value != "" && entry.expire.After(now) {
entry.expire = now.Add(sessionIDTTL)
sessionIDCache[key] = entry
sessionIDCacheMu.Unlock()
return entry.value
}
sessionIDCacheMu.Unlock()
}
newID := uuid.New().String()
sessionIDCacheMu.Lock()
entry, ok = sessionIDCache[key]
if !ok || entry.value == "" || !entry.expire.After(now) {
entry.value = newID
}
entry.expire = now.Add(sessionIDTTL)
sessionIDCache[key] = entry
sessionIDCacheMu.Unlock()
return entry.value
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"

View File

@@ -1,9 +1,7 @@
package executor
package helps
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
@@ -11,100 +9,80 @@ import (
"github.com/tiktoken-go/tokenizer"
)
// tokenizerCache stores tokenizer instances to avoid repeated creation
// tokenizerCache stores tokenizer instances to avoid repeated creation.
var tokenizerCache sync.Map
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
type TokenizerWrapper struct {
Codec tokenizer.Codec
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
type adjustedTokenizer struct {
tokenizer.Codec
adjustmentFactor float64
}
// Count returns the token count with adjustment factor applied
func (tw *TokenizerWrapper) Count(text string) (int, error) {
func (tw *adjustedTokenizer) Count(text string) (int, error) {
count, err := tw.Codec.Count(text)
if err != nil {
return 0, err
}
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
return int(float64(count) * tw.AdjustmentFactor), nil
if tw.adjustmentFactor > 0 && tw.adjustmentFactor != 1.0 {
return int(float64(count) * tw.adjustmentFactor), nil
}
return count, nil
}
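The embedded-codec pattern above lets `adjustedTokenizer` satisfy `tokenizer.Codec` while overriding only `Count`. A minimal sketch of the scaling behaviour, using a hypothetical `fixedCodec` (4 bytes per token, illustrative only) in place of a real tiktoken codec:

```go
package main

import "fmt"

// codec is a stand-in for the tokenizer.Codec interface used above.
type codec interface{ Count(string) (int, error) }

// fixedCodec pretends every 4 bytes are one token (illustrative only).
type fixedCodec struct{}

func (fixedCodec) Count(s string) (int, error) { return len(s) / 4, nil }

// adjusted mirrors adjustedTokenizer: scale the raw count by a factor to
// compensate for systematic under-estimation.
type adjusted struct {
	codec
	factor float64
}

func (a adjusted) Count(s string) (int, error) {
	n, err := a.codec.Count(s)
	if err != nil {
		return 0, err
	}
	if a.factor > 0 && a.factor != 1.0 {
		return int(float64(n) * a.factor), nil
	}
	return n, nil
}

func main() {
	text := string(make([]byte, 400)) // 400 bytes -> 100 raw "tokens"
	raw, _ := fixedCodec{}.Count(text)
	adj, _ := adjusted{codec: fixedCodec{}, factor: 1.1}.Count(text)
	fmt.Println(raw, adj) // 100 110
}
```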
// getTokenizer returns a cached tokenizer for the given model.
// This improves performance by avoiding repeated tokenizer creation.
func getTokenizer(model string) (*TokenizerWrapper, error) {
// Check cache first
if cached, ok := tokenizerCache.Load(model); ok {
return cached.(*TokenizerWrapper), nil
// TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
// For Claude-like models, it applies an adjustment factor since tiktoken may underestimate token counts.
func TokenizerForModel(model string) (tokenizer.Codec, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
if cached, ok := tokenizerCache.Load(sanitized); ok {
return cached.(tokenizer.Codec), nil
}
// Cache miss, create new tokenizer
wrapper, err := tokenizerForModel(model)
enc, err := tokenizerForModel(sanitized)
if err != nil {
return nil, err
}
// Store in cache (use LoadOrStore to handle race conditions)
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
return actual.(*TokenizerWrapper), nil
actual, _ := tokenizerCache.LoadOrStore(sanitized, enc)
return actual.(tokenizer.Codec), nil
}
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
// For Claude models, it applies a 1.1 adjustment factor since tiktoken may underestimate token counts.
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
func tokenizerForModel(sanitized string) (tokenizer.Codec, error) {
if sanitized == "" {
return tokenizer.Get(tokenizer.Cl100kBase)
}
// Claude models use cl100k_base with 1.1 adjustment factor
// because tiktoken may underestimate Claude's actual token count
// Claude models use cl100k_base with an adjustment factor because tiktoken may underestimate.
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
return &adjustedTokenizer{Codec: enc, adjustmentFactor: 1.1}, nil
}
var enc tokenizer.Codec
var err error
switch {
case sanitized == "":
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5.2"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
return tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
enc, err = tokenizer.ForModel(tokenizer.GPT41)
return tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
return tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
enc, err = tokenizer.ForModel(tokenizer.GPT4)
return tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
return tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
enc, err = tokenizer.ForModel(tokenizer.O1)
return tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
enc, err = tokenizer.ForModel(tokenizer.O3)
return tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
return tokenizer.ForModel(tokenizer.O4Mini)
default:
enc, err = tokenizer.Get(tokenizer.O200kBase)
return tokenizer.Get(tokenizer.O200kBase)
}
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
// CountOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -128,22 +106,15 @@ func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
return int64(count), nil
}
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
// This handles Claude's message format with system, messages, and tools.
// Image tokens are estimated based on image dimensions when available.
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
// CountClaudeChatTokens approximates prompt tokens for Claude API chat payloads.
func CountClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -153,185 +124,25 @@ func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32)
imageTokens := 0
// Collect system prompt (can be string or array of content blocks)
collectClaudeSystem(root.Get("system"), &segments)
// Collect messages
collectClaudeMessages(root.Get("messages"), &segments)
// Collect tools
collectClaudeContent(root.Get("system"), &segments, &imageTokens)
collectClaudeMessages(root.Get("messages"), &segments, &imageTokens)
collectClaudeTools(root.Get("tools"), &segments)
joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" {
return 0, nil
return int64(imageTokens), nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
return int64(count + imageTokens), nil
}
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
// extractImageTokens extracts image token estimates from placeholder text.
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
func extractImageTokens(text string) int {
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
total := 0
for _, match := range matches {
if len(match) > 1 {
if tokens, err := strconv.Atoi(match[1]); err == nil {
total += tokens
}
}
}
return total
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image)
return 1000
}
tokens := int(width * height / 750)
// Apply bounds
if tokens < 85 {
tokens = 85
}
if tokens > 1590 {
tokens = 1590
}
return tokens
}
// collectClaudeSystem extracts text from Claude's system field.
// System can be a string or an array of content blocks.
func collectClaudeSystem(system gjson.Result, segments *[]string) {
if !system.Exists() {
return
}
if system.Type == gjson.String {
addIfNotEmpty(segments, system.String())
return
}
if system.IsArray() {
system.ForEach(func(_, block gjson.Result) bool {
blockType := block.Get("type").String()
if blockType == "text" || blockType == "" {
addIfNotEmpty(segments, block.Get("text").String())
}
// Also handle plain string blocks
if block.Type == gjson.String {
addIfNotEmpty(segments, block.String())
}
return true
})
}
}
// collectClaudeMessages extracts text from Claude's messages array.
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments)
return true
})
}
// collectClaudeContent extracts text from Claude's content field.
// Content can be a string or an array of content blocks.
// For images, estimates token count based on dimensions when available.
func collectClaudeContent(content gjson.Result, segments *[]string) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
// Estimate image tokens based on dimensions if available
source := part.Get("source")
if source.Exists() {
width := source.Get("width").Float()
height := source.Get("height").Float()
if width > 0 && height > 0 {
tokens := estimateImageTokens(width, height)
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
} else {
// No dimensions available, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
} else {
// No source info, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
// For unknown types, try to extract any text content
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
}
}
// collectClaudeTools extracts text from Claude's tools array.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
func buildOpenAIUsageJSON(count int64) []byte {
// BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
func BuildOpenAIUsageJSON(count int64) []byte {
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
}
@@ -390,6 +201,10 @@ func collectOpenAIContent(content gjson.Result, segments *[]string) {
}
}
func CollectOpenAIContent(content gjson.Result, segments *[]string) {
collectOpenAIContent(content, segments)
}
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
if !calls.Exists() || !calls.IsArray() {
return
@@ -487,6 +302,98 @@ func appendToolPayload(tool gjson.Result, segments *[]string) {
}
}
func collectClaudeMessages(messages gjson.Result, segments *[]string, imageTokens *int) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments, imageTokens)
return true
})
}
func collectClaudeContent(content gjson.Result, segments *[]string, imageTokens *int) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
source := part.Get("source")
width := source.Get("width").Float()
height := source.Get("height").Float()
if imageTokens != nil {
*imageTokens += estimateImageTokens(width, height)
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments, imageTokens)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
return
}
if content.Type == gjson.JSON {
addIfNotEmpty(segments, content.Raw)
}
}
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image).
return 1000
}
tokens := int(width * height / 750)
if tokens < 85 {
return 85
}
if tokens > 1590 {
return 1590
}
return tokens
}
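Worked through with concrete dimensions, the `width*height/750` estimate with its `[85, 1590]` clamp behaves as follows (self-contained copy of the function above):

```go
package main

import "fmt"

// estimateImageTokens reproduces the clamp above:
// tokens ≈ width*height/750, bounded to [85, 1590].
func estimateImageTokens(width, height float64) int {
	if width <= 0 || height <= 0 {
		return 1000 // default for unknown dimensions
	}
	tokens := int(width * height / 750)
	if tokens < 85 {
		return 85
	}
	if tokens > 1590 {
		return 1590
	}
	return tokens
}

func main() {
	fmt.Println(estimateImageTokens(1000, 1000)) // 1333
	fmt.Println(estimateImageTokens(200, 200))   // 53 raw -> clamped to 85
	fmt.Println(estimateImageTokens(1568, 1568)) // 3278 raw -> clamped to 1590
	fmt.Println(estimateImageTokens(0, 500))     // 1000 default
}
```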
func addIfNotEmpty(segments *[]string, value string) {
if segments == nil {
return

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"bytes"
@@ -15,7 +15,7 @@ import (
"github.com/tidwall/sjson"
)
type usageReporter struct {
type UsageReporter struct {
provider string
model string
authID string
@@ -26,9 +26,9 @@ type usageReporter struct {
once sync.Once
}
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
apiKey := apiKeyFromContext(ctx)
reporter := &usageReporter{
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := APIKeyFromContext(ctx)
reporter := &UsageReporter{
provider: provider,
model: model,
requestedAt: time.Now(),
@@ -42,24 +42,24 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
return reporter
}
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
func (r *UsageReporter) Publish(ctx context.Context, detail usage.Detail) {
r.publishWithOutcome(ctx, detail, false)
}
func (r *usageReporter) publishFailure(ctx context.Context) {
func (r *UsageReporter) PublishFailure(ctx context.Context) {
r.publishWithOutcome(ctx, usage.Detail{}, true)
}
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
func (r *UsageReporter) TrackFailure(ctx context.Context, errPtr *error) {
if r == nil || errPtr == nil {
return
}
if *errPtr != nil {
r.publishFailure(ctx)
r.PublishFailure(ctx)
}
}
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
if r == nil {
return
}
@@ -81,7 +81,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
// It is safe to call multiple times; only the first call wins due to once.Do.
// This is used to ensure request counting even when upstream responses do not
// include any usage fields (tokens), especially for streaming paths.
func (r *usageReporter) ensurePublished(ctx context.Context) {
func (r *UsageReporter) EnsurePublished(ctx context.Context) {
if r == nil {
return
}
@@ -90,7 +90,7 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
})
}
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
if r == nil {
return usage.Record{Detail: detail, Failed: failed}
}
@@ -108,7 +108,7 @@ func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
}
}
func (r *usageReporter) latency() time.Duration {
func (r *UsageReporter) latency() time.Duration {
if r == nil || r.requestedAt.IsZero() {
return 0
}
@@ -119,7 +119,7 @@ func (r *usageReporter) latency() time.Duration {
return latency
}
func apiKeyFromContext(ctx context.Context) string {
func APIKeyFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
@@ -184,7 +184,7 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
return ""
}
func parseCodexUsage(data []byte) (usage.Detail, bool) {
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
usageNode := gjson.ParseBytes(data).Get("response.usage")
if !usageNode.Exists() {
return usage.Detail{}, false
@@ -203,7 +203,7 @@ func parseCodexUsage(data []byte) (usage.Detail, bool) {
return detail, true
}
func parseOpenAIUsage(data []byte) usage.Detail {
func ParseOpenAIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
@@ -238,7 +238,7 @@ func parseOpenAIUsage(data []byte) usage.Detail {
return detail
}
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -247,59 +247,40 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
if !usageNode.Exists() {
return usage.Detail{}, false
}
inputNode := usageNode.Get("prompt_tokens")
if !inputNode.Exists() {
inputNode = usageNode.Get("input_tokens")
}
outputNode := usageNode.Get("completion_tokens")
if !outputNode.Exists() {
outputNode = usageNode.Get("output_tokens")
}
detail := usage.Detail{
InputTokens: usageNode.Get("prompt_tokens").Int(),
OutputTokens: usageNode.Get("completion_tokens").Int(),
InputTokens: inputNode.Int(),
OutputTokens: outputNode.Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
if !cached.Exists() {
cached = usageNode.Get("input_tokens_details.cached_tokens")
}
if cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
if !reasoning.Exists() {
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
}
if reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail, true
}
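The rewritten parser accepts both field families: Chat Completions names (`prompt_tokens`/`completion_tokens`) first, then the Responses-style names (`input_tokens`/`output_tokens`) as a fallback, and likewise for the cached/reasoning detail fields. A minimal sketch of the fallback lookup using `encoding/json` (the real code uses gjson; `parseUsage` and `detail` here are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// detail mirrors the token fields touched above.
type detail struct {
	Input, Output, Total int64
}

// parseUsage prefers the Chat Completions field name and falls back to the
// Responses-style name, as ParseOpenAIStreamUsage does.
func parseUsage(payload []byte) (detail, bool) {
	var doc struct {
		Usage map[string]json.RawMessage `json:"usage"`
	}
	if err := json.Unmarshal(payload, &doc); err != nil || doc.Usage == nil {
		return detail{}, false
	}
	pick := func(primary, fallback string) int64 {
		raw, ok := doc.Usage[primary]
		if !ok {
			raw, ok = doc.Usage[fallback]
		}
		if !ok {
			return 0
		}
		var n int64
		_ = json.Unmarshal(raw, &n)
		return n
	}
	return detail{
		Input:  pick("prompt_tokens", "input_tokens"),
		Output: pick("completion_tokens", "output_tokens"),
		Total:  pick("total_tokens", "total_tokens"), // same name in both formats
	}, true
}

func main() {
	chat := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`)
	resp := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30}}`)
	a, _ := parseUsage(chat)
	b, _ := parseUsage(resp)
	fmt.Println(a.Input, a.Output, b.Input, b.Output)
}
```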
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: usageNode.Get("input_tokens").Int(),
OutputTokens: usageNode.Get("output_tokens").Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
}
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
}
return parseOpenAIResponsesUsageDetail(usageNode)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
usageNode := gjson.GetBytes(payload, "usage")
if !usageNode.Exists() {
return usage.Detail{}, false
}
return parseOpenAIResponsesUsageDetail(usageNode), true
}
func parseClaudeUsage(data []byte) usage.Detail {
func ParseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
@@ -317,7 +298,7 @@ func parseClaudeUsage(data []byte) usage.Detail {
return detail
}
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
func ParseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -352,7 +333,7 @@ func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
return detail
}
func parseGeminiCLIUsage(data []byte) usage.Detail {
func ParseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
if !node.Exists() {
@@ -364,7 +345,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiUsage(data []byte) usage.Detail {
func ParseGeminiUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("usageMetadata")
if !node.Exists() {
@@ -376,7 +357,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -391,7 +372,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
return parseGeminiFamilyUsageDetail(node), true
}
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -406,7 +387,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
return parseGeminiFamilyUsageDetail(node), true
}
func parseAntigravityUsage(data []byte) usage.Detail {
func ParseAntigravityUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
if !node.Exists() {
@@ -421,7 +402,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node)
}
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
func ParseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -590,6 +571,10 @@ func isStopChunkWithoutUsage(jsonBytes []byte) bool {
return !hasUsageMetadata(jsonBytes)
}
func JSONPayload(line []byte) []byte {
return jsonPayload(line)
}
func jsonPayload(line []byte) []byte {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 {

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"testing"
@@ -9,7 +9,7 @@ import (
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
detail := parseOpenAIUsage(data)
detail := ParseOpenAIUsage(data)
if detail.InputTokens != 1 {
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
}
@@ -29,7 +29,7 @@ func TestParseOpenAIUsageChatCompletions(t *testing.T) {
func TestParseOpenAIUsageResponses(t *testing.T) {
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
detail := parseOpenAIUsage(data)
detail := ParseOpenAIUsage(data)
if detail.InputTokens != 10 {
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
}
@@ -48,7 +48,7 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
}
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
reporter := &usageReporter{
reporter := &UsageReporter{
provider: "openai",
model: "gpt-5.4",
requestedAt: time.Now().Add(-1500 * time.Millisecond),

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"crypto/sha256"
@@ -49,7 +49,7 @@ func userIDCacheKey(apiKey string) string {
return hex.EncodeToString(sum[:])
}
func cachedUserID(apiKey string) string {
func CachedUserID(apiKey string) string {
if apiKey == "" {
return generateFakeUserID()
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"testing"
@@ -14,8 +14,8 @@ func resetUserIDCache() {
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-1")
first := CachedUserID("api-key-1")
second := CachedUserID("api-key-1")
if first == "" {
t.Fatal("expected generated user_id to be non-empty")
@@ -28,7 +28,7 @@ func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
resetUserIDCache()
expiredID := cachedUserID("api-key-expired")
expiredID := CachedUserID("api-key-expired")
cacheKey := userIDCacheKey("api-key-expired")
userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{
@@ -37,7 +37,7 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
}
userIDCacheMu.Unlock()
newID := cachedUserID("api-key-expired")
newID := CachedUserID("api-key-expired")
if newID == expiredID {
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
}
@@ -49,8 +49,8 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-2")
first := CachedUserID("api-key-1")
second := CachedUserID("api-key-2")
if first == second {
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
@@ -61,7 +61,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
resetUserIDCache()
key := "api-key-renew"
id := cachedUserID(key)
id := CachedUserID(key)
cacheKey := userIDCacheKey(key)
soon := time.Now()
@@ -72,7 +72,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
}
userIDCacheMu.Unlock()
if refreshed := cachedUserID(key); refreshed != id {
if refreshed := CachedUserID(key); refreshed != id {
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
}

View File

@@ -0,0 +1,188 @@
package helps
import (
"net"
"net/http"
"strings"
"sync"
"time"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
)
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
type utlsRoundTripper struct {
mu sync.Mutex
connections map[string]*http2.ClientConn
pending map[string]*sync.Cond
dialer proxy.Dialer
}
func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct
if proxyURL != "" {
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
if errBuild != nil {
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
dialer = proxyDialer
}
}
return &utlsRoundTripper{
connections: make(map[string]*http2.ClientConn),
pending: make(map[string]*sync.Cond),
dialer: dialer,
}
}
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
t.mu.Lock()
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
t.mu.Unlock()
return h2Conn, nil
}
if cond, ok := t.pending[host]; ok {
cond.Wait()
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
t.mu.Unlock()
return h2Conn, nil
}
}
cond := sync.NewCond(&t.mu)
t.pending[host] = cond
t.mu.Unlock()
h2Conn, err := t.createConnection(host, addr)
t.mu.Lock()
defer t.mu.Unlock()
delete(t.pending, host)
cond.Broadcast()
if err != nil {
return nil, err
}
t.connections[host] = h2Conn
return h2Conn, nil
}
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
conn, err := t.dialer.Dial("tcp", addr)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{ServerName: host}
	tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
	if err := tlsConn.Handshake(); err != nil {
		conn.Close()
		return nil, err
	}
	tr := &http2.Transport{}
	h2Conn, err := tr.NewClientConn(tlsConn)
	if err != nil {
		tlsConn.Close()
		return nil, err
	}
	return h2Conn, nil
}

func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	hostname := req.URL.Hostname()
	port := req.URL.Port()
	if port == "" {
		port = "443"
	}
	addr := net.JoinHostPort(hostname, port)
	h2Conn, err := t.getOrCreateConnection(hostname, addr)
	if err != nil {
		return nil, err
	}
	resp, err := h2Conn.RoundTrip(req)
	if err != nil {
		t.mu.Lock()
		if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
			delete(t.connections, hostname)
		}
		t.mu.Unlock()
		return nil, err
	}
	return resp, nil
}

// anthropicHosts contains the hosts that should use the utls Chrome TLS fingerprint.
var anthropicHosts = map[string]struct{}{
	"api.anthropic.com": {},
}

// fallbackRoundTripper uses utls for Anthropic HTTPS hosts and falls back to
// the standard transport for all other requests (non-HTTPS or non-Anthropic hosts).
type fallbackRoundTripper struct {
	utls     *utlsRoundTripper
	fallback http.RoundTripper
}

func (f *fallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.URL.Scheme == "https" {
		if _, ok := anthropicHosts[strings.ToLower(req.URL.Hostname())]; ok {
			return f.utls.RoundTrip(req)
		}
	}
	return f.fallback.RoundTrip(req)
}

// NewUtlsHTTPClient creates an HTTP client that uses the utls Chrome TLS fingerprint.
// Use it for Claude API requests to match real Claude Code's TLS behavior.
// It falls back to the standard transport for non-HTTPS requests and non-Anthropic hosts.
func NewUtlsHTTPClient(cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
	var proxyURL string
	if auth != nil {
		proxyURL = strings.TrimSpace(auth.ProxyURL)
	}
	if proxyURL == "" && cfg != nil {
		proxyURL = strings.TrimSpace(cfg.ProxyURL)
	}
	utlsRT := newUtlsRoundTripper(proxyURL)
	var standardTransport http.RoundTripper = &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
	}
	if proxyURL != "" {
		if transport := buildProxyTransport(proxyURL); transport != nil {
			standardTransport = transport
		}
	}
	client := &http.Client{
		Transport: &fallbackRoundTripper{
			utls:     utlsRT,
			fallback: standardTransport,
		},
	}
	if timeout > 0 {
		client.Timeout = timeout
	}
	return client
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/uuid"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -66,7 +67,7 @@ func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -86,8 +87,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -106,8 +107,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
}
body = preserveReasoningContentInMessages(body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -116,13 +117,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, err
}
applyIFlowHeaders(httpReq, apiKey, false)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -134,10 +140,10 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -145,25 +151,25 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
log.Errorf("iflow executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
// Ensure usage is recorded even if upstream omits usage metadata.
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
@@ -189,8 +195,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -214,8 +220,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -224,13 +230,18 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
applyIFlowHeaders(httpReq, apiKey, true)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -242,21 +253,21 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, _ := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("iflow executor: close response body error: %v", errClose)
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err
}
@@ -275,9 +286,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
@@ -285,12 +296,12 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
// Guarantee a usage record exists even if the stream never emitted usage data.
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
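The scan loop above feeds every SSE line through helps.ParseOpenAIStreamUsage before translating it, publishing usage only when a data payload carries it. A self-contained sketch of that kind of extraction; parseStreamUsage and the usage struct are hypothetical stand-ins, and the real helper's signature may differ:

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"strings"
)

// usage covers the subset of the OpenAI streaming usage payload the reporters
// publish; field names follow the OpenAI wire format.
type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// parseStreamUsage inspects one SSE line and reports usage when the data
// payload is JSON that carries a non-null "usage" object.
func parseStreamUsage(line []byte) (usage, bool) {
	line = bytes.TrimSpace(line)
	if !bytes.HasPrefix(line, []byte("data:")) {
		return usage{}, false
	}
	payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
	if bytes.Equal(payload, []byte("[DONE]")) {
		return usage{}, false
	}
	var chunk struct {
		Usage *usage `json:"usage"`
	}
	if err := json.Unmarshal(payload, &chunk); err != nil || chunk.Usage == nil {
		return usage{}, false
	}
	return *chunk.Usage, true
}

func main() {
	stream := "data: {\"choices\":[]}\n" +
		"data: {\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":3,\"total_tokens\":15}}\n" +
		"data: [DONE]\n"
	sc := bufio.NewScanner(strings.NewReader(stream))
	for sc.Scan() {
		if u, ok := parseStreamUsage(sc.Bytes()); ok {
			fmt.Println(u.TotalTokens) // 15
		}
	}
}
```

Lines without usage (including the `[DONE]` sentinel) are skipped, which is why the executors also call EnsurePublished afterwards to guarantee a usage record exists.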
@@ -303,17 +314,17 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
enc, err := tokenizerForModel(baseModel)
enc, err := helps.TokenizerForModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, body)
count, err := helps.CountOpenAIChatTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
usageJSON := helps.BuildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translated}, nil
}

View File

@@ -15,7 +15,9 @@ import (
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -45,6 +47,11 @@ func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
if strings.TrimSpace(token) != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
@@ -60,7 +67,7 @@ func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -76,8 +83,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
token := kimiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
@@ -100,8 +107,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = normalizeKimiToolMessageLinks(body)
if err != nil {
return resp, err
@@ -113,13 +120,18 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
applyKimiHeadersWithAuth(httpReq, token, false, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -131,10 +143,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -142,21 +154,21 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Errorf("kimi executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
@@ -176,8 +188,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
baseModel := thinking.ParseSuffix(req.Model).ModelName
token := kimiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
@@ -204,8 +216,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = normalizeKimiToolMessageLinks(body)
if err != nil {
return nil, err
@@ -217,13 +229,18 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
applyKimiHeadersWithAuth(httpReq, token, true, auth)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -235,17 +252,17 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("kimi executor: close response body error: %v", errClose)
}
@@ -265,9 +282,9 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
@@ -279,8 +296,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -65,15 +66,15 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
@@ -95,8 +96,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
if opts.Alt == "responses/compact" {
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
translated = updated
@@ -129,7 +130,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -141,10 +142,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -152,23 +153,23 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
// Ensure we at least record the request even if upstream doesn't return usage
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
@@ -179,8 +180,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
@@ -197,8 +198,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
@@ -232,7 +233,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -244,17 +245,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
@@ -274,9 +275,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if len(line) == 0 {
continue
@@ -294,12 +295,12 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
// Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
@@ -318,17 +319,17 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
return cliproxyexecutor.Response{}, err
}
enc, err := tokenizerForModel(modelForCounting)
enc, err := helps.TokenizerForModel(modelForCounting)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, translated)
count, err := helps.CountOpenAIChatTokens(enc, translated)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
usageJSON := helps.BuildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
}

View File

@@ -13,7 +13,9 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -23,7 +25,7 @@ import (
)
const (
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenUserAgent = "QwenCode/0.13.2 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration
)
@@ -154,7 +156,7 @@ func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int,
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay()
retryAfter = &cooldown
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
}
return errCode, retryAfter
}
@@ -202,7 +204,7 @@ func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -217,7 +219,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
@@ -228,8 +230,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -247,8 +249,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -256,12 +258,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
applyQwenHeaders(httpReq, token, false)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authLabel, authType, authValue string
if auth != nil {
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -273,10 +280,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -284,23 +291,23 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Errorf("qwen executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
@@ -320,7 +327,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
@@ -331,8 +338,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -357,8 +364,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -366,12 +373,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
applyQwenHeaders(httpReq, token, true)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authLabel, authType, authValue string
if auth != nil {
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -383,19 +395,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
@@ -415,9 +427,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
@@ -429,8 +441,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -449,17 +461,17 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
modelName = baseModel
}
enc, err := tokenizerForModel(modelName)
enc, err := helps.TokenizerForModel(modelName)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, body)
count, err := helps.CountOpenAIChatTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
usageJSON := helps.BuildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translated}, nil
}
@@ -508,16 +520,15 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
r.Header["X-DashScope-UserAgent"] = []string{qwenUserAgent}
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
r.Header.Set("Sec-Fetch-Mode", "cors")
r.Header.Set("X-Stainless-Lang", "js")
r.Header.Set("X-Stainless-Arch", "arm64")
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
r.Header["X-DashScope-CacheControl"] = []string{"enable"}
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("X-Stainless-Os", "MacOS")
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
r.Header["X-DashScope-AuthType"] = []string{"qwen-oauth"}
r.Header.Set("X-Stainless-Runtime", "node")
if stream {


@@ -446,6 +446,7 @@ func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
if email, ok := metadata["email"].(string); ok && email != "" {
auth.Attributes["email"] = email
}
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
return auth, nil
}


@@ -595,6 +595,7 @@ func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Aut
LastRefreshedAt: time.Time{},
NextRefreshAfter: time.Time{},
}
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
return auth, nil
}


@@ -310,6 +310,7 @@ func (s *PostgresStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error)
LastRefreshedAt: time.Time{},
NextRefreshAfter: time.Time{},
}
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
auths = append(auths, auth)
}
if err = rows.Err(); err != nil {


@@ -330,32 +330,45 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
// Reorder parts for 'model' role to ensure thinking block is first
// Reorder parts for 'model' role:
// 1. Thinking parts first (Antigravity API requirement)
// 2. Regular parts (text, inlineData, etc.)
// 3. FunctionCall parts last
//
// Moving functionCall parts to the end prevents tool_use↔tool_result
// pairing breakage: the Antigravity API internally splits model messages
// at functionCall boundaries. If a text part follows a functionCall, the
// split creates an extra assistant turn between tool_use and tool_result,
// which Claude rejects with "tool_use ids were found without tool_result
// blocks immediately after".
if role == "model" {
partsResult := gjson.GetBytes(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
var otherParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else {
otherParts = append(otherParts, part)
}
}
if len(thinkingParts) > 0 {
firstPartIsThinking := parts[0].Get("thought").Bool()
if !firstPartIsThinking || len(thinkingParts) > 1 {
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
if len(parts) > 1 {
var thinkingParts []gjson.Result
var regularParts []gjson.Result
var functionCallParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else if part.Get("functionCall").Exists() {
functionCallParts = append(functionCallParts, part)
} else {
regularParts = append(regularParts, part)
}
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
}
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
}
for _, p := range regularParts {
newParts = append(newParts, p.Value())
}
for _, p := range functionCallParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
}
}
}


@@ -361,6 +361,167 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
}
}
func TestConvertClaudeRequestToAntigravity_ReorderTextAfterFunctionCall(t *testing.T) {
// Bug: text part after tool_use in an assistant message causes Antigravity
// to split at functionCall boundary, creating an extra assistant turn that
// breaks tool_use↔tool_result adjacency (upstream issue #989).
// Fix: reorder parts so functionCall comes last.
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me check..."},
{
"type": "tool_use",
"id": "call_abc",
"name": "Read",
"input": {"file": "test.go"}
},
{"type": "text", "text": "Reading the file now"}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_abc",
"content": "file content"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 3 {
t.Fatalf("Expected 3 parts, got %d", len(parts))
}
// Text parts should come before functionCall
if parts[0].Get("text").String() != "Let me check..." {
t.Errorf("Expected first text part first, got %s", parts[0].Raw)
}
if parts[1].Get("text").String() != "Reading the file now" {
t.Errorf("Expected second text part second, got %s", parts[1].Raw)
}
if !parts[2].Get("functionCall").Exists() {
t.Errorf("Expected functionCall last, got %s", parts[2].Raw)
}
if parts[2].Get("functionCall.name").String() != "Read" {
t.Errorf("Expected functionCall name 'Read', got '%s'", parts[2].Get("functionCall.name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ReorderParallelFunctionCalls(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Reading both files."},
{
"type": "tool_use",
"id": "call_1",
"name": "Read",
"input": {"file": "a.go"}
},
{"type": "text", "text": "And this one too."},
{
"type": "tool_use",
"id": "call_2",
"name": "Read",
"input": {"file": "b.go"}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 4 {
t.Fatalf("Expected 4 parts, got %d", len(parts))
}
if parts[0].Get("text").String() != "Reading both files." {
t.Errorf("Expected first text, got %s", parts[0].Raw)
}
if parts[1].Get("text").String() != "And this one too." {
t.Errorf("Expected second text, got %s", parts[1].Raw)
}
if parts[2].Get("functionCall.name").String() != "Read" || parts[2].Get("functionCall.id").String() != "call_1" {
t.Errorf("Expected fc1 third, got %s", parts[2].Raw)
}
if parts[3].Get("functionCall.name").String() != "Read" || parts[3].Get("functionCall.id").String() != "call_2" {
t.Errorf("Expected fc2 fourth, got %s", parts[3].Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ReorderThinkingAndTextBeforeFunctionCall(t *testing.T) {
cache.ClearSignatureCache("")
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think about this..."
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Before thinking"},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_xyz",
"name": "Bash",
"input": {"command": "ls"}
},
{"type": "text", "text": "After tool call"}
]
}
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// contents.1 = assistant message (contents.0 = user)
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
if len(parts) != 4 {
t.Fatalf("Expected 4 parts, got %d", len(parts))
}
// Order: thinking → text → text → functionCall
if !parts[0].Get("thought").Bool() {
t.Error("First part should be thinking")
}
if parts[1].Get("functionCall").Exists() || parts[1].Get("thought").Bool() {
t.Errorf("Second part should be text, got %s", parts[1].Raw)
}
if parts[2].Get("functionCall").Exists() || parts[2].Get("thought").Bool() {
t.Errorf("Third part should be text, got %s", parts[2].Raw)
}
if !parts[3].Get("functionCall").Exists() {
t.Errorf("Last part should be functionCall, got %s", parts[3].Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",


@@ -165,29 +165,22 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
// Process messages and transform them to Claude Code format
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
messageIndex := 0
systemMessageIndex := -1
messages.ForEach(func(_, message gjson.Result) bool {
role := message.Get("role").String()
contentResult := message.Get("content")
switch role {
case "system":
if systemMessageIndex == -1 {
systemMsg := []byte(`{"role":"user","content":[]}`)
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMsg)
systemMessageIndex = messageIndex
messageIndex++
}
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
}
return true
})
@@ -269,6 +262,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
}
return true
})
// Preserve a minimal conversational turn for system-only inputs.
// Claude payloads with top-level system instructions but no messages are risky for downstream validation.
if messageIndex == 0 {
system := gjson.GetBytes(out, "system")
if system.Exists() && system.IsArray() && len(system.Array()) > 0 {
fallbackMsg := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
out, _ = sjson.SetRawBytes(out, "messages.-1", fallbackMsg)
}
}
}
// Tools mapping: OpenAI tools -> Claude Code tools


@@ -135,3 +135,111 @@ func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
t.Fatalf("Unexpected image URL: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_SystemRoleBecomesTopLevelSystem(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system")
if !system.IsArray() {
t.Fatalf("Expected top-level system array, got %s", system.Raw)
}
if len(system.Array()) != 1 {
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system.Array()), system.Raw)
}
if got := system.Get("0.type").String(); got != "text" {
t.Fatalf("Expected system block type %q, got %q", "text", got)
}
if got := system.Get("0.text").String(); got != "You are a helpful assistant." {
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
t.Fatalf("Expected user text %q, got %q", "Hello", got)
}
}
func TestConvertOpenAIRequestToClaude_MultipleSystemMessagesMergedIntoTopLevelSystem(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "Rule 1"},
{"role": "system", "content": [{"type": "text", "text": "Rule 2"}]},
{"role": "user", "content": "Hello"}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system").Array()
if len(system) != 2 {
t.Fatalf("Expected 2 system blocks, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
}
if got := system[0].Get("text").String(); got != "Rule 1" {
t.Fatalf("Expected first system text %q, got %q", "Rule 1", got)
}
if got := system[1].Get("text").String(); got != "Rule 2" {
t.Fatalf("Expected second system text %q, got %q", "Rule 2", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
t.Fatalf("Expected user text %q, got %q", "Hello", got)
}
}
func TestConvertOpenAIRequestToClaude_SystemOnlyInputKeepsFallbackUserMessage(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "You are a helpful assistant."}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system").Array()
if len(system) != 1 {
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
}
if got := system[0].Get("text").String(); got != "You are a helpful assistant." {
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 fallback message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected fallback message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.type").String(); got != "text" {
t.Fatalf("Expected fallback content type %q, got %q", "text", got)
}
if got := messages[0].Get("content.0.text").String(); got != "" {
t.Fatalf("Expected fallback text %q, got %q", "", got)
}
}


@@ -284,12 +284,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
}
// Process the output array for content and function calls
var toolCalls [][]byte
outputResult := responseResult.Get("output")
if outputResult.IsArray() {
outputArray := outputResult.Array()
var contentText string
var reasoningText string
var toolCalls [][]byte
for _, outputItem := range outputArray {
outputType := outputItem.Get("type").String()
@@ -367,8 +367,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if statusResult := responseResult.Get("status"); statusResult.Exists() {
status := statusResult.String()
if status == "completed" {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
}
}


@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"sort"
"strings"
"sync/atomic"
"time"
@@ -16,6 +17,7 @@ import (
type oaiToResponsesStateReasoning struct {
ReasoningID string
ReasoningData string
OutputIndex int
}
type oaiToResponsesState struct {
Seq int
@@ -29,16 +31,19 @@ type oaiToResponsesState struct {
MsgTextBuf map[int]*strings.Builder
ReasoningBuf strings.Builder
Reasonings []oaiToResponsesStateReasoning
FuncArgsBuf map[int]*strings.Builder // index -> args
FuncNames map[int]string // index -> name
FuncCallIDs map[int]string // index -> call_id
FuncArgsBuf map[string]*strings.Builder
FuncNames map[string]string
FuncCallIDs map[string]string
FuncOutputIx map[string]int
MsgOutputIx map[int]int
NextOutputIx int
// message item state per output index
MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
MsgItemDone map[int]bool // whether message done events were emitted
// function item done state
FuncArgsDone map[int]bool
FuncItemDone map[int]bool
FuncArgsDone map[string]bool
FuncItemDone map[string]bool
// usage aggregation
PromptTokens int64
CachedTokens int64
@@ -60,15 +65,17 @@ func emitRespEvent(event string, payload []byte) []byte {
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &oaiToResponsesState{
FuncArgsBuf: make(map[int]*strings.Builder),
FuncNames: make(map[int]string),
FuncCallIDs: make(map[int]string),
FuncArgsBuf: make(map[string]*strings.Builder),
FuncNames: make(map[string]string),
FuncCallIDs: make(map[string]string),
FuncOutputIx: make(map[string]int),
MsgOutputIx: make(map[int]int),
MsgTextBuf: make(map[int]*strings.Builder),
MsgItemAdded: make(map[int]bool),
MsgContentAdded: make(map[int]bool),
MsgItemDone: make(map[int]bool),
FuncArgsDone: make(map[int]bool),
FuncItemDone: make(map[int]bool),
FuncArgsDone: make(map[string]bool),
FuncItemDone: make(map[string]bool),
Reasonings: make([]oaiToResponsesStateReasoning, 0),
}
}
@@ -125,6 +132,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
}
nextSeq := func() int { st.Seq++; return st.Seq }
allocOutputIndex := func() int {
ix := st.NextOutputIx
st.NextOutputIx++
return ix
}
toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
var out [][]byte
if !st.Started {
@@ -135,14 +148,17 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.ReasoningBuf.Reset()
st.ReasoningID = ""
st.ReasoningIndex = 0
st.FuncArgsBuf = make(map[int]*strings.Builder)
st.FuncNames = make(map[int]string)
st.FuncCallIDs = make(map[int]string)
st.FuncArgsBuf = make(map[string]*strings.Builder)
st.FuncNames = make(map[string]string)
st.FuncCallIDs = make(map[string]string)
st.FuncOutputIx = make(map[string]int)
st.MsgOutputIx = make(map[int]int)
st.NextOutputIx = 0
st.MsgItemAdded = make(map[int]bool)
st.MsgContentAdded = make(map[int]bool)
st.MsgItemDone = make(map[int]bool)
st.FuncArgsDone = make(map[int]bool)
st.FuncItemDone = make(map[int]bool)
st.FuncArgsDone = make(map[string]bool)
st.FuncItemDone = make(map[string]bool)
st.PromptTokens = 0
st.CachedTokens = 0
st.CompletionTokens = 0
@@ -185,7 +201,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
st.ReasoningID = ""
}
@@ -201,10 +217,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
stopReasoning(st.ReasoningBuf.String())
st.ReasoningBuf.Reset()
}
if _, exists := st.MsgOutputIx[idx]; !exists {
st.MsgOutputIx[idx] = allocOutputIndex()
}
msgOutputIndex := st.MsgOutputIx[idx]
if !st.MsgItemAdded[idx] {
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
out = append(out, emitRespEvent("response.output_item.added", item))
st.MsgItemAdded[idx] = true
@@ -213,7 +233,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
part, _ = sjson.SetBytes(part, "output_index", idx)
part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
part, _ = sjson.SetBytes(part, "content_index", 0)
out = append(out, emitRespEvent("response.content_part.added", part))
st.MsgContentAdded[idx] = true
@@ -222,7 +242,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
msg, _ = sjson.SetBytes(msg, "output_index", idx)
msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
msg, _ = sjson.SetBytes(msg, "content_index", 0)
msg, _ = sjson.SetBytes(msg, "delta", c.String())
out = append(out, emitRespEvent("response.output_text.delta", msg))
@@ -238,10 +258,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
// On first appearance, add reasoning item and part
if st.ReasoningID == "" {
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
st.ReasoningIndex = idx
st.ReasoningIndex = allocOutputIndex()
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
out = append(out, emitRespEvent("response.output_item.added", item))
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
@@ -269,6 +289,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
// Before emitting any function events, if a message is open for this index,
// close its text/content to match Codex expected ordering.
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
msgOutputIndex := st.MsgOutputIx[idx]
fullText := ""
if b := st.MsgTextBuf[idx]; b != nil {
fullText = b.String()
@@ -276,7 +297,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
done, _ = sjson.SetBytes(done, "output_index", idx)
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
done, _ = sjson.SetBytes(done, "content_index", 0)
done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitRespEvent("response.output_text.done", done))
@@ -284,69 +305,72 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
partDone, _ = sjson.SetBytes(partDone, "output_index", idx)
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitRespEvent("response.content_part.done", partDone))
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.MsgItemDone[idx] = true
}
// Only emit item.added once per tool call and preserve call_id across chunks.
newCallID := tcs.Get("0.id").String()
nameChunk := tcs.Get("0.function.name").String()
if nameChunk != "" {
st.FuncNames[idx] = nameChunk
}
existingCallID := st.FuncCallIDs[idx]
effectiveCallID := existingCallID
shouldEmitItem := false
if existingCallID == "" && newCallID != "" {
// First time seeing a valid call_id for this index
effectiveCallID = newCallID
st.FuncCallIDs[idx] = newCallID
shouldEmitItem = true
}
if shouldEmitItem && effectiveCallID != "" {
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
o, _ = sjson.SetBytes(o, "output_index", idx)
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
name := st.FuncNames[idx]
o, _ = sjson.SetBytes(o, "item.name", name)
out = append(out, emitRespEvent("response.output_item.added", o))
}
// Ensure args buffer exists for this index
if st.FuncArgsBuf[idx] == nil {
st.FuncArgsBuf[idx] = &strings.Builder{}
}
// Append arguments delta if available and we have a valid call_id to reference
if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
// Prefer an already known call_id; fall back to newCallID if first time
refCallID := st.FuncCallIDs[idx]
if refCallID == "" {
refCallID = newCallID
tcs.ForEach(func(_, tc gjson.Result) bool {
toolIndex := int(tc.Get("index").Int())
key := toolStateKey(idx, toolIndex)
newCallID := tc.Get("id").String()
nameChunk := tc.Get("function.name").String()
if nameChunk != "" {
st.FuncNames[key] = nameChunk
}
if refCallID != "" {
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
ad, _ = sjson.SetBytes(ad, "output_index", idx)
ad, _ = sjson.SetBytes(ad, "delta", args.String())
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
existingCallID := st.FuncCallIDs[key]
effectiveCallID := existingCallID
shouldEmitItem := false
if existingCallID == "" && newCallID != "" {
effectiveCallID = newCallID
st.FuncCallIDs[key] = newCallID
st.FuncOutputIx[key] = allocOutputIndex()
shouldEmitItem = true
}
st.FuncArgsBuf[idx].WriteString(args.String())
}
if shouldEmitItem && effectiveCallID != "" {
outputIndex := st.FuncOutputIx[key]
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
o, _ = sjson.SetBytes(o, "output_index", outputIndex)
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
out = append(out, emitRespEvent("response.output_item.added", o))
}
if st.FuncArgsBuf[key] == nil {
st.FuncArgsBuf[key] = &strings.Builder{}
}
if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
refCallID := st.FuncCallIDs[key]
if refCallID == "" {
refCallID = newCallID
}
if refCallID != "" {
outputIndex := st.FuncOutputIx[key]
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
ad, _ = sjson.SetBytes(ad, "delta", args.String())
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
}
st.FuncArgsBuf[key].WriteString(args.String())
}
return true
})
}
}
@@ -360,15 +384,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
for i := range st.MsgItemAdded {
idxs = append(idxs, i)
}
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
for _, i := range idxs {
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
msgOutputIndex := st.MsgOutputIx[i]
fullText := ""
if b := st.MsgTextBuf[i]; b != nil {
fullText = b.String()
@@ -376,7 +395,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
done, _ = sjson.SetBytes(done, "output_index", i)
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
done, _ = sjson.SetBytes(done, "content_index", 0)
done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitRespEvent("response.output_text.done", done))
@@ -384,14 +403,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
partDone, _ = sjson.SetBytes(partDone, "output_index", i)
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitRespEvent("response.content_part.done", partDone))
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
out = append(out, emitRespEvent("response.output_item.done", itemDone))
@@ -407,43 +426,42 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
// Emit function call done events for any active function calls
if len(st.FuncCallIDs) > 0 {
idxs := make([]int, 0, len(st.FuncCallIDs))
for i := range st.FuncCallIDs {
idxs = append(idxs, i)
keys := make([]string, 0, len(st.FuncCallIDs))
for key := range st.FuncCallIDs {
keys = append(keys, key)
}
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs {
callID := st.FuncCallIDs[i]
if callID == "" || st.FuncItemDone[i] {
sort.Slice(keys, func(i, j int) bool {
left := st.FuncOutputIx[keys[i]]
right := st.FuncOutputIx[keys[j]]
return left < right || (left == right && keys[i] < keys[j])
})
for _, key := range keys {
callID := st.FuncCallIDs[key]
if callID == "" || st.FuncItemDone[key] {
continue
}
outputIndex := st.FuncOutputIx[key]
args := "{}"
if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
args = b.String()
}
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
fcDone, _ = sjson.SetBytes(fcDone, "output_index", i)
fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i])
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.FuncItemDone[i] = true
st.FuncArgsDone[i] = true
st.FuncItemDone[key] = true
st.FuncArgsDone[key] = true
}
}
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
@@ -516,28 +534,21 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
}
// Build response.output using aggregated buffers
outputsWrapper := []byte(`{"arr":[]}`)
type completedOutputItem struct {
index int
raw []byte
}
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
}
}
// Append message items in ascending index order
if len(st.MsgItemAdded) > 0 {
midxs := make([]int, 0, len(st.MsgItemAdded))
for i := range st.MsgItemAdded {
midxs = append(midxs, i)
}
for i := 0; i < len(midxs); i++ {
for j := i + 1; j < len(midxs); j++ {
if midxs[j] < midxs[i] {
midxs[i], midxs[j] = midxs[j], midxs[i]
}
}
}
for _, i := range midxs {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
@@ -545,37 +556,29 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.SetBytes(item, "content.0.text", txt)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
}
}
if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf))
for i := range st.FuncArgsBuf {
idxs = append(idxs, i)
}
// small-N sort without extra imports
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs {
for key := range st.FuncArgsBuf {
args := ""
if b := st.FuncArgsBuf[i]; b != nil {
if b := st.FuncArgsBuf[key]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[i]
name := st.FuncNames[i]
callID := st.FuncCallIDs[key]
name := st.FuncNames[key]
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
}
}
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
for _, item := range outputItems {
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
}
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
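The hunk above replaces per-kind appends with a single ordered pass: each completed item records the output index it was allocated during streaming, and `response.output` is sorted by that index before emission. A minimal standalone sketch of that ordering, with the item payload simplified to a string rather than the real JSON bytes:

```go
package main

import (
	"fmt"
	"sort"
)

// completedOutputItem pairs a stream-allocated output index with its payload,
// mirroring the struct introduced in the diff (raw simplified to a string).
type completedOutputItem struct {
	index int
	raw   string
}

// orderedOutput sorts items by their allocated index so the final output
// reflects stream order, not the order of the per-kind maps they came from.
func orderedOutput(items []completedOutputItem) []string {
	sort.Slice(items, func(i, j int) bool { return items[i].index < items[j].index })
	out := make([]string, 0, len(items))
	for _, it := range items {
		out = append(out, it.raw)
	}
	return out
}

func main() {
	// Items arrive grouped by kind (reasoning, message, function_call),
	// not in the order their indexes were allocated.
	items := []completedOutputItem{
		{index: 2, raw: "function_call"},
		{index: 0, raw: "reasoning"},
		{index: 1, raw: "message"},
	}
	fmt.Println(orderedOutput(items)) // [reasoning message function_call]
}
```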


@@ -0,0 +1,305 @@
package responses
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
t.Helper()
lines := strings.Split(string(chunk), "\n")
if len(lines) < 2 {
t.Fatalf("unexpected SSE chunk: %q", chunk)
}
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
if !gjson.Valid(dataLine) {
t.Fatalf("invalid SSE data JSON: %q", dataLine)
}
return event, gjson.Parse(dataLine)
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
in := []string{
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
var param any
var out [][]byte
for _, line := range in {
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
}
addedNames := map[string]string{}
doneArgs := map[string]string{}
doneNames := map[string]string{}
outputItems := map[string]gjson.Result{}
for _, chunk := range out {
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
switch ev {
case "response.output_item.added":
if data.Get("item.type").String() != "function_call" {
continue
}
addedNames[data.Get("item.call_id").String()] = data.Get("item.name").String()
case "response.output_item.done":
if data.Get("item.type").String() != "function_call" {
continue
}
callID := data.Get("item.call_id").String()
doneArgs[callID] = data.Get("item.arguments").String()
doneNames[callID] = data.Get("item.name").String()
case "response.completed":
output := data.Get("response.output")
for _, item := range output.Array() {
if item.Get("type").String() == "function_call" {
outputItems[item.Get("call_id").String()] = item
}
}
}
}
if len(addedNames) != 2 {
t.Fatalf("expected 2 function_call added events, got %d", len(addedNames))
}
if len(doneArgs) != 2 {
t.Fatalf("expected 2 function_call done events, got %d", len(doneArgs))
}
if addedNames["call_read"] != "read" {
t.Fatalf("unexpected added name for call_read: %q", addedNames["call_read"])
}
if addedNames["call_glob"] != "glob" {
t.Fatalf("unexpected added name for call_glob: %q", addedNames["call_glob"])
}
if !gjson.Valid(doneArgs["call_read"]) {
t.Fatalf("invalid JSON args for call_read: %q", doneArgs["call_read"])
}
if !gjson.Valid(doneArgs["call_glob"]) {
t.Fatalf("invalid JSON args for call_glob: %q", doneArgs["call_glob"])
}
if strings.Contains(doneArgs["call_read"], "}{") {
t.Fatalf("call_read args were concatenated: %q", doneArgs["call_read"])
}
if strings.Contains(doneArgs["call_glob"], "}{") {
t.Fatalf("call_glob args were concatenated: %q", doneArgs["call_glob"])
}
if doneNames["call_read"] != "read" {
t.Fatalf("unexpected done name for call_read: %q", doneNames["call_read"])
}
if doneNames["call_glob"] != "glob" {
t.Fatalf("unexpected done name for call_glob: %q", doneNames["call_glob"])
}
if got := gjson.Get(doneArgs["call_read"], "filePath").String(); got != `C:\repo` {
t.Fatalf("unexpected filePath for call_read: %q", got)
}
if got := gjson.Get(doneArgs["call_glob"], "path").String(); got != `C:\repo` {
t.Fatalf("unexpected path for call_glob: %q", got)
}
if got := gjson.Get(doneArgs["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
t.Fatalf("unexpected pattern for call_glob: %q", got)
}
if len(outputItems) != 2 {
t.Fatalf("expected 2 function_call items in response.output, got %d", len(outputItems))
}
if outputItems["call_read"].Get("name").String() != "read" {
t.Fatalf("unexpected response.output name for call_read: %q", outputItems["call_read"].Get("name").String())
}
if outputItems["call_glob"].Get("name").String() != "glob" {
t.Fatalf("unexpected response.output name for call_glob: %q", outputItems["call_glob"].Get("name").String())
}
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
in := []string{
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
var param any
var out [][]byte
for _, line := range in {
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
}
type fcEvent struct {
outputIndex int64
name string
arguments string
}
added := map[string]fcEvent{}
done := map[string]fcEvent{}
for _, chunk := range out {
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
switch ev {
case "response.output_item.added":
if data.Get("item.type").String() != "function_call" {
continue
}
callID := data.Get("item.call_id").String()
added[callID] = fcEvent{
outputIndex: data.Get("output_index").Int(),
name: data.Get("item.name").String(),
}
case "response.output_item.done":
if data.Get("item.type").String() != "function_call" {
continue
}
callID := data.Get("item.call_id").String()
done[callID] = fcEvent{
outputIndex: data.Get("output_index").Int(),
name: data.Get("item.name").String(),
arguments: data.Get("item.arguments").String(),
}
}
}
if len(added) != 2 {
t.Fatalf("expected 2 function_call added events, got %d", len(added))
}
if len(done) != 2 {
t.Fatalf("expected 2 function_call done events, got %d", len(done))
}
if added["call_choice0"].name != "glob" {
t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
}
if added["call_choice1"].name != "read" {
t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
}
if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
}
if !gjson.Valid(done["call_choice0"].arguments) {
t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
}
if !gjson.Valid(done["call_choice1"].arguments) {
t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
}
if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
}
if done["call_choice0"].name != "glob" {
t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
}
if done["call_choice1"].name != "read" {
t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
}
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
in := []string{
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
var param any
var out [][]byte
for _, line := range in {
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
}
var messageOutputIndex int64 = -1
var toolOutputIndex int64 = -1
for _, chunk := range out {
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
if ev != "response.output_item.added" {
continue
}
switch data.Get("item.type").String() {
case "message":
if data.Get("item.id").String() == "msg_resp_mixed_0" {
messageOutputIndex = data.Get("output_index").Int()
}
case "function_call":
if data.Get("item.call_id").String() == "call_choice1" {
toolOutputIndex = data.Get("output_index").Int()
}
}
}
if messageOutputIndex < 0 {
t.Fatal("did not find message output index")
}
if toolOutputIndex < 0 {
t.Fatal("did not find tool output index")
}
if messageOutputIndex == toolOutputIndex {
t.Fatalf("expected distinct output indexes for message and tool call, both got %d", messageOutputIndex)
}
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
in := []string{
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
var param any
var out [][]byte
for _, line := range in {
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
}
var doneIndexes []int64
var completedOrder []string
for _, chunk := range out {
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
switch ev {
case "response.output_item.done":
if data.Get("item.type").String() == "function_call" {
doneIndexes = append(doneIndexes, data.Get("output_index").Int())
}
case "response.completed":
for _, item := range data.Get("response.output").Array() {
if item.Get("type").String() == "function_call" {
completedOrder = append(completedOrder, item.Get("call_id").String())
}
}
}
}
if len(doneIndexes) != 2 {
t.Fatalf("expected 2 function_call done indexes, got %d", len(doneIndexes))
}
if doneIndexes[0] >= doneIndexes[1] {
t.Fatalf("expected ascending done output indexes, got %v", doneIndexes)
}
if len(completedOrder) != 2 {
t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedOrder))
}
if completedOrder[0] != "call_glob" || completedOrder[1] != "call_read" {
t.Fatalf("unexpected completed function_call order: %v", completedOrder)
}
}
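The tests above lock down the new per-(choice, tool) state keying. The core idea can be sketched in isolation; the real `toolStateKey` signature and state maps may differ, and the `alloc` helper here is purely illustrative:

```go
package main

import "fmt"

// toolStateKey builds a composite key so tool calls sharing a tool index
// under different choices never collide (the collision the tests guard against).
func toolStateKey(choiceIdx, toolIdx int) string {
	return fmt.Sprintf("%d:%d", choiceIdx, toolIdx)
}

func main() {
	callIDs := map[string]string{}
	outputIx := map[string]int{}
	nextOutputIx := 0

	// alloc records a call_id and allocates an output index exactly once
	// per (choice, tool) pair, ignoring repeat chunks for the same pair.
	alloc := func(choice, tool int, callID string) {
		key := toolStateKey(choice, tool)
		if _, seen := callIDs[key]; !seen {
			callIDs[key] = callID
			outputIx[key] = nextOutputIx
			nextOutputIx++
		}
	}

	alloc(0, 0, "call_choice0")
	alloc(1, 0, "call_choice1") // same tool index, different choice: new slot
	alloc(0, 0, "call_choice0") // repeat chunk: no new allocation

	fmt.Println(outputIx[toolStateKey(0, 0)], outputIx[toolStateKey(1, 0)]) // 0 1
}
```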


@@ -201,6 +201,7 @@ var zhStrings = map[string]string{
"usage_output": "输出",
"usage_cached": "缓存",
"usage_reasoning": "思考",
"usage_time": "时间",
// ── Logs ──
"logs_title": "📋 日志",
@@ -352,6 +353,7 @@ var enStrings = map[string]string{
"usage_output": "Output",
"usage_cached": "Cached",
"usage_reasoning": "Reasoning",
"usage_time": "Time",
// ── Logs ──
"logs_title": "📋 Logs",


@@ -248,6 +248,9 @@ func (m usageTabModel) renderContent() string {
// Token type breakdown from details
sb.WriteString(m.renderTokenBreakdown(stats))
// Latency breakdown from details
sb.WriteString(m.renderLatencyBreakdown(stats))
}
}
}
@@ -308,6 +311,57 @@ func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var totalLatency int64
var count int
var minLatency, maxLatency int64
first := true
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
latencyMs := int64(getFloat(dm, "latency_ms"))
if latencyMs <= 0 {
continue
}
totalLatency += latencyMs
count++
if first {
minLatency = latencyMs
maxLatency = latencyMs
first = false
} else {
if latencyMs < minLatency {
minLatency = latencyMs
}
if latencyMs > maxLatency {
maxLatency = latencyMs
}
}
}
if count == 0 {
return ""
}
avgLatency := totalLatency / int64(count)
return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
avgLatency, minLatency, maxLatency)
}
// renderBarChart renders a simple ASCII horizontal bar chart.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
if maxBarWidth < 10 {


@@ -0,0 +1,134 @@
package tui
import (
"strings"
"testing"
)
func TestRenderLatencyBreakdown(t *testing.T) {
tests := []struct {
		name         string
		modelStats   map[string]any
		wantEmpty    bool
		wantContains string
	}{
		{
			name:       "no details",
			modelStats: map[string]any{},
			wantEmpty:  true,
		},
		{
			name: "empty details",
			modelStats: map[string]any{
				"details": []any{},
			},
			wantEmpty: true,
		},
		{
			name: "details with zero latency",
			modelStats: map[string]any{
				"details": []any{
					map[string]any{
						"latency_ms": float64(0),
					},
				},
			},
			wantEmpty: true,
		},
		{
			name: "single request with latency",
			modelStats: map[string]any{
				"details": []any{
					map[string]any{
						"latency_ms": float64(1500),
					},
				},
			},
			wantEmpty:    false,
			wantContains: "avg 1500ms min 1500ms max 1500ms",
		},
		{
			name: "multiple requests with varying latency",
			modelStats: map[string]any{
				"details": []any{
					map[string]any{
						"latency_ms": float64(100),
					},
					map[string]any{
						"latency_ms": float64(200),
					},
					map[string]any{
						"latency_ms": float64(300),
					},
				},
			},
			wantEmpty:    false,
			wantContains: "avg 200ms min 100ms max 300ms",
		},
		{
			name: "mixed valid and invalid latency values",
			modelStats: map[string]any{
				"details": []any{
					map[string]any{
						"latency_ms": float64(500),
					},
					map[string]any{
						"latency_ms": float64(0),
					},
					map[string]any{
						"latency_ms": float64(1500),
					},
				},
			},
			wantEmpty:    false,
			wantContains: "avg 1000ms min 500ms max 1500ms",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m := usageTabModel{}
			result := m.renderLatencyBreakdown(tt.modelStats)
			if tt.wantEmpty {
				if result != "" {
					t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
				}
				return
			}
			if result == "" {
				t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
				return
			}
			if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
				t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
			}
		})
	}
}

func TestUsageTimeTranslations(t *testing.T) {
	prevLocale := CurrentLocale()
	t.Cleanup(func() {
		SetLocale(prevLocale)
	})
	tests := []struct {
		locale string
		want   string
	}{
		{locale: "en", want: "Time"},
		{locale: "zh", want: "时间"},
	}
	for _, tt := range tests {
		t.Run(tt.locale, func(t *testing.T) {
			SetLocale(tt.locale)
			if got := T("usage_time"); got != tt.want {
				t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
			}
		})
	}
}
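The `renderLatencyBreakdown` cases above fully pin down the expected aggregation: zero or missing `latency_ms` samples are ignored, and valid samples produce an `avg … min … max …` line. A minimal sketch consistent with those expectations follows; `latencySummary` is a hypothetical reconstruction for illustration, not the project's actual implementation.

```go
package main

import "fmt"

// latencySummary collects positive latency_ms samples from
// modelStats["details"] and reports avg/min/max, or "" when no
// valid samples exist — matching the test table above.
func latencySummary(modelStats map[string]any) string {
	details, _ := modelStats["details"].([]any)
	var sum, lo, hi float64
	count := 0
	for _, d := range details {
		entry, ok := d.(map[string]any)
		if !ok {
			continue
		}
		ms, ok := entry["latency_ms"].(float64)
		if !ok || ms <= 0 {
			continue // zero/invalid samples are skipped, per the tests
		}
		if count == 0 || ms < lo {
			lo = ms
		}
		if ms > hi {
			hi = ms
		}
		sum += ms
		count++
	}
	if count == 0 {
		return ""
	}
	return fmt.Sprintf("avg %.0fms min %.0fms max %.0fms", sum/float64(count), lo, hi)
}

func main() {
	stats := map[string]any{"details": []any{
		map[string]any{"latency_ms": float64(100)},
		map[string]any{"latency_ms": float64(200)},
		map[string]any{"latency_ms": float64(300)},
	}}
	fmt.Println(latencySummary(stats)) // avg 200ms min 100ms max 300ms
}
```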


@@ -57,6 +57,12 @@ func GetProviderName(modelName string) []string {
		return providers
	}
	// Fallback: if cursor provider has registered models, route unknown models to it.
	// Cursor acts as a universal proxy supporting multiple model families (Claude, GPT, Gemini, etc.).
	if models := registry.GetGlobalRegistry().GetAvailableModelsByProvider("cursor"); len(models) > 0 {
		return []string{"cursor"}
	}
	return providers
}
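The fallback in the hunk above keeps known providers and routes anything unrecognized to `cursor`, provided cursor has models registered. The lookup-then-fallback shape can be sketched with a stub registry; the maps and the `getProviderName` name here are illustrative stand-ins, not the project's actual registry API.

```go
package main

import "fmt"

// providersByModel stands in for the real model registry lookup.
var providersByModel = map[string][]string{
	"gemini-2.0-flash": {"gemini"},
}

// cursorModels stands in for GetAvailableModelsByProvider("cursor").
var cursorModels = []string{"claude-3.5-sonnet", "gpt-4o"}

// getProviderName mirrors the diff above: known models keep their
// providers; unknown models fall back to cursor when it has any
// registered models.
func getProviderName(model string) []string {
	if providers, ok := providersByModel[model]; ok && len(providers) > 0 {
		return providers
	}
	if len(cursorModels) > 0 {
		return []string{"cursor"}
	}
	return nil
}

func main() {
	fmt.Println(getProviderName("gemini-2.0-flash"))   // [gemini]
	fmt.Println(getProviderName("some-unknown-model")) // [cursor]
}
```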


@@ -80,6 +80,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
	if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
		changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
	}
	if oldCfg.QuotaExceeded.AntigravityCredits != newCfg.QuotaExceeded.AntigravityCredits {
		changes = append(changes, fmt.Sprintf("quota-exceeded.antigravity-credits: %t -> %t", oldCfg.QuotaExceeded.AntigravityCredits, newCfg.QuotaExceeded.AntigravityCredits))
	}
	if oldCfg.Routing.Strategy != newCfg.Routing.Strategy {
		changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy))
@@ -256,6 +259,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
	if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
		changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
	}
	if oldCfg.RemoteManagement.DisableAutoUpdatePanel != newCfg.RemoteManagement.DisableAutoUpdatePanel {
		changes = append(changes, fmt.Sprintf("remote-management.disable-auto-update-panel: %t -> %t", oldCfg.RemoteManagement.DisableAutoUpdatePanel, newCfg.RemoteManagement.DisableAutoUpdatePanel))
	}
	oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
	newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
	if oldPanelRepo != newPanelRepo {
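The hunks above repeat a single pattern: compare one setting on the old and new config, and when it differs, append a `name: old -> new` entry. That pattern can be sketched generically; the `diffBool` helper name is illustrative, not an API the project defines.

```go
package main

import "fmt"

// diffBool records a "name: old -> new" change line when a boolean
// setting differs, mirroring the repeated branches in
// BuildConfigChangeDetails above.
func diffBool(changes []string, name string, oldV, newV bool) []string {
	if oldV != newV {
		changes = append(changes, fmt.Sprintf("%s: %t -> %t", name, oldV, newV))
	}
	return changes
}

func main() {
	var changes []string
	changes = diffBool(changes, "quota-exceeded.antigravity-credits", false, true)
	changes = diffBool(changes, "remote-management.disable-auto-update-panel", true, true) // unchanged: no entry
	fmt.Println(changes) // [quota-exceeded.antigravity-credits: false -> true]
}
```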


@@ -20,10 +20,11 @@ func TestBuildConfigChangeDetails(t *testing.T) {
RestrictManagementToLocalhost: false,
},
RemoteManagement: config.RemoteManagement{
AllowRemote: false,
SecretKey: "old",
DisableControlPanel: false,
PanelGitHubRepository: "repo-old",
AllowRemote: false,
SecretKey: "old",
DisableControlPanel: false,
DisableAutoUpdatePanel: false,
PanelGitHubRepository: "repo-old",
},
OAuthExcludedModels: map[string][]string{
"providerA": {"m1"},
@@ -54,10 +55,11 @@ func TestBuildConfigChangeDetails(t *testing.T) {
},
},
RemoteManagement: config.RemoteManagement{
AllowRemote: true,
SecretKey: "new",
DisableControlPanel: true,
PanelGitHubRepository: "repo-new",
AllowRemote: true,
SecretKey: "new",
DisableControlPanel: true,
DisableAutoUpdatePanel: true,
PanelGitHubRepository: "repo-new",
},
OAuthExcludedModels: map[string][]string{
"providerA": {"m1", "m2"},
@@ -88,6 +90,7 @@ func TestBuildConfigChangeDetails(t *testing.T) {
expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream")
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)")
expectContains(t, details, "remote-management.allow-remote: false -> true")
expectContains(t, details, "remote-management.disable-auto-update-panel: false -> true")
expectContains(t, details, "remote-management.secret-key: updated")
expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)")
expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)")
@@ -226,7 +229,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
MaxRetryCredentials: 1,
MaxRetryInterval: 1,
WebsocketAuth: false,
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false},
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
CodexKey: []config.CodexKey{{APIKey: "x1"}},
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
@@ -250,7 +253,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
MaxRetryCredentials: 3,
MaxRetryInterval: 3,
WebsocketAuth: true,
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true, AntigravityCredits: true},
ClaudeKey: []config.ClaudeKey{
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
{APIKey: "c2"},
@@ -265,9 +268,10 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
},
RemoteManagement: config.RemoteManagement{
DisableControlPanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
DisableControlPanel: true,
DisableAutoUpdatePanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: true,
@@ -293,12 +297,14 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
expectContains(t, details, "nonstream-keepalive-interval: 0 -> 5")
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
expectContains(t, details, "quota-exceeded.antigravity-credits: false -> true")
expectContains(t, details, "api-keys count: 1 -> 2")
expectContains(t, details, "claude-api-key count: 1 -> 2")
expectContains(t, details, "codex-api-key count: 1 -> 2")
expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true")
expectContains(t, details, "ampcode.upstream-api-key: removed")
expectContains(t, details, "remote-management.disable-control-panel: false -> true")
expectContains(t, details, "remote-management.disable-auto-update-panel: false -> true")
expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo")
expectContains(t, details, "remote-management.secret-key: deleted")
}
@@ -315,7 +321,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
MaxRetryCredentials: 1,
MaxRetryInterval: 1,
WebsocketAuth: false,
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false},
GeminiKey: []config.GeminiKey{
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
},
@@ -336,10 +342,11 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
ForceModelMappings: false,
},
RemoteManagement: config.RemoteManagement{
AllowRemote: false,
DisableControlPanel: false,
PanelGitHubRepository: "old/repo",
SecretKey: "old",
AllowRemote: false,
DisableControlPanel: false,
DisableAutoUpdatePanel: false,
PanelGitHubRepository: "old/repo",
SecretKey: "old",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: false,
@@ -368,7 +375,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
MaxRetryCredentials: 3,
MaxRetryInterval: 3,
WebsocketAuth: true,
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true, AntigravityCredits: true},
GeminiKey: []config.GeminiKey{
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
},
@@ -389,10 +396,11 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
ForceModelMappings: true,
},
RemoteManagement: config.RemoteManagement{
AllowRemote: true,
DisableControlPanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
AllowRemote: true,
DisableControlPanel: true,
DisableAutoUpdatePanel: true,
PanelGitHubRepository: "new/repo",
SecretKey: "",
},
SDKConfig: sdkconfig.SDKConfig{
RequestLog: true,
@@ -430,6 +438,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
expectContains(t, changes, "ws-auth: false -> true")
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
expectContains(t, changes, "quota-exceeded.antigravity-credits: false -> true")
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")
@@ -460,6 +469,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)")
expectContains(t, changes, "remote-management.allow-remote: false -> true")
expectContains(t, changes, "remote-management.disable-control-panel: false -> true")
expectContains(t, changes, "remote-management.disable-auto-update-panel: false -> true")
expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo")
expectContains(t, changes, "remote-management.secret-key: deleted")
expectContains(t, changes, "openai-compatibility:")


@@ -157,6 +157,7 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
			}
		}
	}
	coreauth.ApplyCustomHeadersFromMetadata(a)
	ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
	// For codex auth files, extract plan_type from the JWT id_token.
	if provider == "codex" {
@@ -233,6 +234,11 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
	if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" {
		attrs["note"] = noteVal
	}
	for k, v := range primary.Attributes {
		if strings.HasPrefix(k, "header:") && strings.TrimSpace(v) != "" {
			attrs[k] = v
		}
	}
	metadataCopy := map[string]any{
		"email":      email,
		"project_id": projectID,


@@ -69,10 +69,14 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
// Create a valid auth file
authData := map[string]any{
"type": "claude",
"email": "test@example.com",
"proxy_url": "http://proxy.local",
"prefix": "test-prefix",
"type": "claude",
"email": "test@example.com",
"proxy_url": "http://proxy.local",
"prefix": "test-prefix",
"headers": map[string]string{
" X-Test ": " value ",
"X-Empty": " ",
},
"disable_cooling": true,
"request_retry": 2,
}
@@ -110,6 +114,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
if auths[0].ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
}
if got := auths[0].Attributes["header:X-Test"]; got != "value" {
t.Errorf("expected header:X-Test value, got %q", got)
}
if _, ok := auths[0].Attributes["header:X-Empty"]; ok {
t.Errorf("expected header:X-Empty to be absent, got %q", auths[0].Attributes["header:X-Empty"])
}
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
}
@@ -450,8 +460,9 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
Prefix: "test-prefix",
ProxyURL: "http://proxy.local",
Attributes: map[string]string{
"source": "test-source",
"path": "/path/to/auth",
"source": "test-source",
"path": "/path/to/auth",
"header:X-Tra": "value",
},
}
metadata := map[string]any{
@@ -506,6 +517,9 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
if v.Attributes["runtime_only"] != "true" {
t.Error("expected runtime_only=true")
}
if got := v.Attributes["header:X-Tra"]; got != "value" {
t.Errorf("expected virtual %d header:X-Tra %q, got %q", i, "value", got)
}
if v.Attributes["gemini_virtual_parent"] != "primary-id" {
t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"])
}
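The header tests above assert a specific normalization: keys and values are trimmed, and entries whose trimmed value is empty (like `X-Empty`) never become `header:`-prefixed attributes. A minimal sketch of that behavior follows; `applyCustomHeaders` is an illustrative stand-in for the real `coreauth.ApplyCustomHeadersFromMetadata` helper, whose signature is not shown in this diff.

```go
package main

import (
	"fmt"
	"strings"
)

// applyCustomHeaders copies a headers map into attrs under
// "header:" keys, trimming whitespace and dropping entries whose
// trimmed key or value is empty — matching the test expectations.
func applyCustomHeaders(headers map[string]string, attrs map[string]string) {
	for k, v := range headers {
		key := strings.TrimSpace(k)
		val := strings.TrimSpace(v)
		if key == "" || val == "" {
			continue // " " values like X-Empty never become attributes
		}
		attrs["header:"+key] = val
	}
}

func main() {
	attrs := map[string]string{}
	applyCustomHeaders(map[string]string{" X-Test ": " value ", "X-Empty": " "}, attrs)
	fmt.Println(attrs["header:X-Test"]) // value
	_, ok := attrs["header:X-Empty"]
	fmt.Println(ok) // false
}
```

The same trimmed attributes then survive into Gemini virtual auths via the `strings.HasPrefix(k, "header:")` copy loop shown earlier.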

Some files were not shown because too many files have changed in this diff.