Compare commits

..

66 Commits

Author SHA1 Message Date
Luis Pater
2df35449fe Fix executor compat helpers 2026-04-02 12:20:12 +08:00
Luis Pater
c744179645 Merge PR #479 2026-04-02 12:15:33 +08:00
Luis Pater
9720b03a6b Merge pull request #477 from ben-vargas/plus-main
fix(copilot): route Gemini preview models to chat endpoint and correct context lengths
2026-04-02 11:36:51 +08:00
Luis Pater
f2c0f3d325 Merge pull request #476 from hungthai1401/fix/ghc-gpt54mini
Fix GitHub Copilot gpt-5.4-mini endpoint routing
2026-04-02 11:36:26 +08:00
Luis Pater
4f99bc54f1 test: update codex header expectations 2026-04-02 11:19:37 +08:00
Luis Pater
913f4a9c5f test: fix executor tests after helpers refactor 2026-04-02 11:12:30 +08:00
Luis Pater
25d1c18a3f fix: scope experimental cch signing to billing header 2026-04-02 11:03:11 +08:00
Luis Pater
d09dd4d0b2 Merge commit '15c2f274ea690c9a7c9db22f9f454af869db5375' into dev 2026-04-02 10:59:54 +08:00
Luis Pater
474fb042da Merge pull request #2476 from router-for-me/cherry-pick/pr-2438-to-dev
Cherry-pick PR #2438 onto dev
2026-04-02 10:36:50 +08:00
Michael
8435c3d7be feat(tui): show time in usage details 2026-04-02 10:35:13 +08:00
Luis Pater
e783d0a62e Merge pull request #2441 from MonsterQiu/issue-2421-alias-before-suspension
fix(auth): resolve oauth aliases before suspension checks
2026-04-02 10:27:39 +08:00
Luis Pater
b05f575e9b Merge pull request #2444 from 0oAstro/fix/codex-nonstream-finish-reason-tool-calls
fix(codex): set finish_reason to "tool_calls" in non-streaming response when tool calls are present
2026-04-02 10:01:25 +08:00
edlsh
15c2f274ea fix: preserve cloak config defaults when mode omitted 2026-04-01 13:20:11 -04:00
edlsh
37249339ac feat: add opt-in experimental Claude cch signing 2026-04-01 13:03:17 -04:00
Luis Pater
c422d16beb Merge pull request #2398 from 7RPH/fix/responses-sse-framing
fix: preserve SSE event boundaries for Responses streams
2026-04-02 00:46:51 +08:00
Luis Pater
66cd50f603 Merge pull request #2468 from router-for-me/ip
fix(openai): improve client IP retrieval in websocket handler
2026-04-02 00:03:35 +08:00
hkfires
caa529c282 fix(openai): improve client IP retrieval in websocket handler 2026-04-01 20:16:01 +08:00
hkfires
51a4379bf4 refactor(openai): remove websocket body log truncation limit 2026-04-01 18:11:43 +08:00
Luis Pater
acf98ed10e fix(openai): add session reference counter and cache lifecycle management for websocket tools 2026-04-01 17:28:50 +08:00
Luis Pater
d1c07a091e fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency 2026-04-01 17:16:49 +08:00
Ben Vargas
c1a8adf1ab feat(registry): add GitHub Copilot gemini-3.1-pro-preview model 2026-04-01 01:25:03 -06:00
Ben Vargas
08e078fc25 fix(openai): route copilot Gemini preview models to chat endpoint 2026-04-01 01:24:58 -06:00
Luis Pater
105a21548f fix(codex): centralize session management with global store and add tests for executor session lifecycle 2026-04-01 13:17:10 +08:00
Luis Pater
1734aa1664 fix(codex): prioritize websocket-enabled credentials across priority tiers in scheduler logic 2026-04-01 12:51:12 +08:00
Luis Pater
ca11b236a7 refactor(runtime, openai): simplify header management and remove redundant websocket logging logic 2026-04-01 11:57:31 +08:00
Luis Pater
330e12d3c2 fix(codex): conditionally set Session_id header for Mac OS user agents and clean up redundant logic 2026-04-01 11:11:45 +08:00
Thai Nguyen Hung
bd09c0bf09 feat(registry): add gpt-5.4-mini model to GitHub Copilot registry 2026-04-01 10:04:38 +07:00
Luis Pater
b468ca79c3 Merge branch 'dev' of github.com:router-for-me/CLIProxyAPI into dev 2026-04-01 03:09:03 +08:00
Luis Pater
d2c7e4e96a refactor(runtime): move executor utilities to helps package and update references 2026-04-01 03:08:20 +08:00
Luis Pater
1c7003ff68 Merge pull request #2452 from Lucaszmv/fix-qwen-cli-v0.13.2
fix(qwen): update CLI simulation to v0.13.2 and adjust header casing
2026-04-01 02:44:27 +08:00
Lucaszmv
1b44364e78 fix(qwen): update CLI simulation to v0.13.2 2026-03-31 15:19:07 -03:00
0oAstro
ec77f4a4f5 fix(codex): set finish_reason to tool_calls in non-streaming response when tool calls are present 2026-03-31 14:12:15 +05:30
MonsterQiu
f611dd6e96 refactor(auth): dedupe route-aware model support checks 2026-03-31 15:42:25 +08:00
MonsterQiu
07b7c1a1e0 fix(auth): resolve oauth aliases before suspension checks 2026-03-31 14:27:14 +08:00
Luis Pater
51fd58d74f fix(codex): use normalizeCodexInstructions to set default instructions 2026-03-31 12:16:57 +08:00
Luis Pater
faae9c2f7c Merge pull request #2422 from MonsterQiu/fix/codex-compact-instructions
fix(codex): add default instructions for /responses/compact
2026-03-31 12:14:20 +08:00
Luis Pater
bc3a6e4646 Merge pull request #2434 from MonsterQiu/fix/codex-responses-null-instructions
fix(codex): normalize null instructions for /responses requests
2026-03-31 12:01:21 +08:00
Luis Pater
b09b03e35e Merge pull request #2424 from possible055/fix/websocket-transcript-replacement
fix(openai): handle transcript replacement after websocket v2 compaction
2026-03-31 11:00:33 +08:00
Luis Pater
16231947e7 Merge pull request #2426 from xixiwenxuanhe/feature/antigravity-credits
feat(antigravity): add AI credits quota fallback
2026-03-31 10:51:40 +08:00
MonsterQiu
39b9a38fbc fix(codex): normalize null instructions across responses paths 2026-03-31 10:32:39 +08:00
MonsterQiu
bd855abec9 fix(codex): normalize null instructions for responses requests 2026-03-31 10:29:02 +08:00
Luis Pater
7c3c2e9f64 Merge pull request #2417 from CharTyr/fix/amp-streaming-thinking-regression
fix(amp): 修复流式响应中 thinking block 被错误抑制导致的 TUI 空白回复
2026-03-31 10:12:13 +08:00
Luis Pater
c10f8ae2e2 Fixed: #2420
docs(readme): remove ProxyPal section from all README translations
2026-03-31 07:23:02 +08:00
xixiwenxuanhe
a0bf33eca6 fix(antigravity): preserve fallback and honor config gate 2026-03-31 00:14:05 +08:00
xixiwenxuanhe
88dd9c715d feat(antigravity): add AI credits quota fallback 2026-03-30 23:58:12 +08:00
apparition
a3e21df814 fix(openai): avoid developer transcript resets
- Narrow websocket transcript replacement detection to assistant outputs and function calls
- Preserve existing merge behavior for follow-up developer messages without previous_response_id
- Add a regression test covering mid-session developer message updates
2026-03-30 23:33:16 +08:00
MonsterQiu
d3b94c9241 fix(codex): normalize null instructions for compact requests 2026-03-30 22:58:05 +08:00
apparition
c1d7599829 fix(openai): handle transcript replacement after websocket compaction
- Add shouldReplaceWebsocketTranscript() to detect historical model output in input
- Add normalizeResponseTranscriptReplacement() for full transcript reset handling
- Prevent duplicate stale turn-state when clients replace local history post-compaction
- Avoid orphaned function_call items from incremental append on compact transcripts
- Add unit tests for transcript replacement detection and state reset behavior
2026-03-30 22:44:58 +08:00
MonsterQiu
d11936f292 fix(codex): add default instructions for /responses/compact 2026-03-30 22:44:46 +08:00
Luis Pater
17363edf25 fix(auth): skip downtime for request-scoped 404 errors in model state management 2026-03-30 22:22:42 +08:00
CharTyr
279cbbbb8a fix(amp): don't suppress thinking blocks in streaming mode
Reverts the streaming thinking suppression introduced in b15453c.
rewriteStreamEvent should only inject signatures and rewrite model
names — suppressing thinking blocks in streaming mode breaks SSE
index alignment and causes the Amp TUI to render empty responses
on the second message onward (especially with model-mapped
non-Claude providers like GPT-5.4).

Non-streaming responses still suppress thinking when tool_use is
present via rewriteModelInResponse.
2026-03-30 20:09:32 +08:00
Luis Pater
486cd4c343 Merge pull request #2409 from sususu98/fix/tool-use-pairing-break
fix(antigravity): reorder model parts to prevent tool_use↔tool_result pairing breakage
2026-03-30 16:59:46 +08:00
sususu98
25feceb783 fix(antigravity): reorder model parts to prevent tool_use↔tool_result pairing breakage
When a Claude assistant message contains [text, tool_use, text], the
Antigravity API internally splits the model message at functionCall
boundaries, creating an extra assistant turn between tool_use and the
following tool_result. Claude then rejects with:

  tool_use ids were found without tool_result blocks immediately after

Fix: extend the existing 2-way part reordering (thinking-first) to a
3-way partition: thinking → regular → functionCall. This ensures
functionCall parts are always last, so Antigravity's split cannot
insert an extra assistant turn before the user's tool_result.

Fixes #989
2026-03-30 15:09:33 +08:00
Luis Pater
d26752250d Merge pull request #2403 from CharTyr/clean-pr
fix(amp): 修复Amp CLI 集成 缺失/无效 signature 导致的 TUI 崩溃与上游 400 问题
2026-03-30 12:54:15 +08:00
CharTyr
b15453c369 fix(amp): address PR review - stream thinking suppression, SSE detection, test init
- Call suppressAmpThinking in rewriteStreamEvent for streaming path
- Handle nil return from suppressAmpThinking to skip suppressed events
- Narrow looksLikeSSEChunk to line-prefix detection (HasPrefix vs Contains)
- Initialize suppressedContentBlock map in test
2026-03-30 00:42:04 -04:00
CharTyr
04ba8c8bc3 feat(amp): sanitize signatures and handle stream suppression for Amp compatibility 2026-03-29 22:23:18 -04:00
Luis Pater
6570692291 Merge pull request #2400 from router-for-me/revert-2374-codex-cache-clean
Revert "fix(codex): restore prompt cache continuity for Codex requests"
2026-03-29 22:19:39 +08:00
trph
f73d55ddaa fix: simplify responses SSE suffix handling 2026-03-29 22:19:25 +08:00
Luis Pater
13aa5b3375 Revert "fix(codex): restore prompt cache continuity for Codex requests" 2026-03-29 22:18:14 +08:00
trph
0fcc02fbea fix: tighten responses SSE review follow-up 2026-03-29 22:10:28 +08:00
trph
c03883ccf0 fix: address responses SSE review feedback 2026-03-29 22:00:46 +08:00
trph
134a9eac9d fix: preserve SSE event boundaries for Responses streams 2026-03-29 17:23:16 +08:00
Luis Pater
6d8de0ade4 feat(auth): implement weighted provider rotation for improved scheduling fairness 2026-03-29 13:49:01 +08:00
Luis Pater
1587ff5e74 Merge pull request #2389 from router-for-me/claude
fix(claude): add default max_tokens for models
2026-03-29 13:03:20 +08:00
hkfires
f033d3a6df fix(claude): enhance ensureModelMaxTokens to use registered max_completion_tokens and fallback to default 2026-03-29 13:00:43 +08:00
hkfires
145e0e0b5d fix(claude): add default max_tokens for models 2026-03-29 12:46:00 +08:00
70 changed files with 5182 additions and 1807 deletions

View File

@@ -14,6 +14,95 @@ This project only accepts pull requests that relate to third-party provider supp
If you need to submit any non-third-party provider changes, please open them against the [mainline](https://github.com/router-for-me/CLIProxyAPI) repository.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
## Who is with us?
Those projects are based on CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
### [Quotio](https://github.com/nguyenphutrong/quotio)
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AI is an AI assistant tool designed specifically for restricted environments. It provides a stealthy operation
mode without windows or traces, and enables cross-device AI Q&A interaction and control via the local area network (
LAN). Essentially, it is an automated collaboration layer of "screen/audio capture + AI inference + low-friction delivery",
helping users to immersively use AI assistants across applications on controlled devices or in restricted environments.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
Those projects are ports of CLIProxyAPI or inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
Never stop coding. Smart routing to FREE & low-cost AI models with automatic fallback.
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

View File

@@ -12,6 +12,92 @@
如果需要提交任何非第三方供应商支持的 Pull Request请提交到[主线](https://github.com/router-for-me/CLIProxyAPI)版本。
1. Fork 仓库
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`
3. 提交您的更改(`git commit -m 'Add some amazing feature'`
4. 推送到分支(`git push origin feature/amazing-feature`
5. 打开 Pull Request
## 谁与我们在一起?
这些项目基于 CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型Gemini, Codex, Antigravity无需 API 密钥。
### [Quotio](https://github.com/nguyenphutrong/quotio)
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话Claude Code、Codex、Gemini CLI提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
Windows 托盘应用,基于 PowerShell 脚本实现不依赖任何第三方库。主要功能包括自动创建快捷方式、静默运行、密码管理、通道切换Main / Plus以及自动下载与更新。
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君是一款用于管理AI编程助手的跨平台桌面应用支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具本地代理实现多账户配额跟踪和一键配置。
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口、无痕迹的隐蔽运行方式,并通过局域网实现跨设备的 AI 问答交互与控制。本质上是一个「屏幕/音频采集 + AI 推理 + 低摩擦投送」的自动化协作层,帮助用户在受控设备/受限环境下沉浸式跨应用地使用 AI 助手。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI易于安装使用自研格式转换OpenAI/Claude/Gemini/Ollama、组合系统与自动回退、多账户管理指数退避、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
代码不止,创新不停。智能路由至免费及低成本 AI 模型,并支持自动故障转移。
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -126,10 +126,6 @@ CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデルGemini、Codex、Antigravityを即座に切り替えるCLIラッパー - APIキー不要
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
CLIProxyAPI管理用のmacOSネイティブGUIOAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
### [Quotio](https://github.com/nguyenphutrong/quotio)
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要

View File

@@ -96,6 +96,7 @@ max-retry-interval: 30
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
# Routing strategy for selecting credentials when multiple match.
routing:
@@ -177,6 +178,8 @@ nonstream-keepalive-interval: 0
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks

1
go.mod
View File

@@ -83,6 +83,7 @@ require (
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pierrec/xxHash v0.1.5
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect

2
go.sum
View File

@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=

View File

@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
return
}
// Sanitize request body: remove thinking blocks with invalid signatures
// to prevent upstream API 400 errors
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
// Restore the body for the handler to read
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -259,10 +263,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
} else if len(providers) > 0 {
// Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
// Wrap with ResponseRewriter for local providers too, because upstream
// proxies (e.g. NewAPI) may return a different model name and lack
// Amp-required fields like thinking.signature.
rewriter := NewResponseRewriter(c.Writer, modelName)
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
rewriter.Flush()
} else {
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

View File

@@ -2,6 +2,7 @@ package amp
import (
"bytes"
"fmt"
"net/http"
"strings"
@@ -12,20 +13,23 @@ import (
)
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
// It's used to rewrite model names in responses when model mapping is used
// It is used to rewrite model names in responses when model mapping is used
// and to keep Amp-compatible response shapes.
type ResponseRewriter struct {
gin.ResponseWriter
body *bytes.Buffer
originalModel string
isStreaming bool
body *bytes.Buffer
originalModel string
isStreaming bool
suppressedContentBlock map[int]struct{}
}
// NewResponseRewriter creates a new response rewriter for model name substitution
// NewResponseRewriter creates a new response rewriter for model name substitution.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
return &ResponseRewriter{
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
suppressedContentBlock: make(map[int]struct{}),
}
}
@@ -33,15 +37,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
func looksLikeSSEChunk(data []byte) bool {
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
// Heuristics are intentionally simple and cheap.
return bytes.Contains(data, []byte("data:")) ||
bytes.Contains(data, []byte("event:")) ||
bytes.Contains(data, []byte("message_start")) ||
bytes.Contains(data, []byte("message_delta")) ||
bytes.Contains(data, []byte("content_block_start")) ||
bytes.Contains(data, []byte("content_block_delta")) ||
bytes.Contains(data, []byte("content_block_stop")) ||
bytes.Contains(data, []byte("\n\n"))
// We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
for _, line := range bytes.Split(data, []byte("\n")) {
trimmed := bytes.TrimSpace(line)
if bytes.HasPrefix(trimmed, []byte("data:")) ||
bytes.HasPrefix(trimmed, []byte("event:")) {
return true
}
}
return false
}
func (rw *ResponseRewriter) enableStreaming(reason string) error {
@@ -106,7 +110,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
return rw.body.Write(data)
}
// Flush writes the buffered response with model names rewritten
func (rw *ResponseRewriter) Flush() {
if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
@@ -115,26 +118,68 @@ func (rw *ResponseRewriter) Flush() {
return
}
if rw.body.Len() > 0 {
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
// Update Content-Length to match the rewritten body size, since
// signature injection and model name changes alter the payload length.
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
}
}
}
// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
// The Amp client struggles when both thinking and tool_use blocks are present
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
for index, block := range gjson.GetBytes(data, "content").Array() {
blockType := block.Get("type").String()
if blockType != "tool_use" && blockType != "thinking" {
continue
}
signaturePath := fmt.Sprintf("content.%d.signature", index)
if gjson.GetBytes(data, signaturePath).Exists() {
continue
}
var err error
data, err = sjson.SetBytes(data, signaturePath, "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
break
}
}
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
var err error
data, err = sjson.SetBytes(data, "content_block.signature", "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
}
}
return data
}
func (rw *ResponseRewriter) markSuppressedContentBlock(index int) {
if rw.suppressedContentBlock == nil {
rw.suppressedContentBlock = make(map[int]struct{})
}
rw.suppressedContentBlock[index] = struct{}{}
}
func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool {
_, ok := rw.suppressedContentBlock[index]
return ok
}
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount {
var err error
data, err = sjson.SetBytes(data, "content", filtered.Value())
@@ -142,13 +187,41 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
}
}
}
}
eventType := gjson.GetBytes(data, "type").String()
indexResult := gjson.GetBytes(data, "index")
if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() {
rw.markSuppressedContentBlock(int(indexResult.Int()))
return nil
}
if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" {
if indexResult.Exists() {
rw.markSuppressedContentBlock(int(indexResult.Int()))
}
return nil
}
if eventType == "content_block_stop" && indexResult.Exists() {
index := int(indexResult.Int())
if rw.isSuppressedContentBlock(index) {
delete(rw.suppressedContentBlock, index)
return nil
}
}
return data
}
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
}
if rw.originalModel == "" {
return data
}
@@ -160,24 +233,148 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
return data
}
// rewriteStreamChunk rewrites model names in SSE stream chunks
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
if rw.originalModel == "" {
return chunk
lines := bytes.Split(chunk, []byte("\n"))
var out [][]byte
i := 0
for i < len(lines) {
line := lines[i]
trimmed := bytes.TrimSpace(line)
// Case 1: "event:" line - look ahead for its "data:" line
if bytes.HasPrefix(trimmed, []byte("event: ")) {
// Scan forward past blank lines to find the data: line
dataIdx := -1
for j := i + 1; j < len(lines); j++ {
t := bytes.TrimSpace(lines[j])
if len(t) == 0 {
continue
}
if bytes.HasPrefix(t, []byte("data: ")) {
dataIdx = j
}
break
}
if dataIdx >= 0 {
// Found event+data pair - process through rewriter
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
// Emit event line
out = append(out, line)
// Emit blank lines between event and data
for k := i + 1; k < dataIdx; k++ {
out = append(out, lines[k])
}
// Emit rewritten data
out = append(out, append([]byte("data: "), rewritten...))
i = dataIdx + 1
continue
}
}
// No data line found (orphan event from cross-chunk split)
// Pass it through as-is - the data will arrive in the next chunk
out = append(out, line)
i++
continue
}
// Case 2: standalone "data:" line (no preceding event: in this chunk)
if bytes.HasPrefix(trimmed, []byte("data: ")) {
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
out = append(out, append([]byte("data: "), rewritten...))
i++
continue
}
}
// Case 3: everything else
out = append(out, line)
i++
}
// SSE format: "data: {json}\n\n"
lines := bytes.Split(chunk, []byte("\n"))
for i, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) {
jsonData := bytes.TrimPrefix(line, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
// Rewrite JSON in the data line
rewritten := rw.rewriteModelInResponse(jsonData)
lines[i] = append([]byte("data: "), rewritten...)
return bytes.Join(out, []byte("\n"))
}
// rewriteStreamEvent processes a single JSON event in the SSE stream.
// It rewrites model names and ensures signature fields exist.
// NOTE: streaming mode does NOT suppress thinking blocks - they are
// passed through with signature injection to avoid breaking SSE index
// alignment and TUI rendering.
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.originalModel)
}
}
}
return bytes.Join(lines, []byte("\n"))
return data
}
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// from the messages array in a request body before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
modified := false
for msgIdx, msg := range messages.Array() {
if msg.Get("role").String() != "assistant" {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
var keepBlocks []interface{}
removedCount := 0
for _, block := range content.Array() {
blockType := block.Get("type").String()
if blockType == "thinking" {
sig := block.Get("signature")
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
removedCount++
continue
}
}
keepBlocks = append(keepBlocks, block.Value())
}
if removedCount > 0 {
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
var err error
if len(keepBlocks) == 0 {
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
} else {
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
}
if err != nil {
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
continue
}
modified = true
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
}
}
if modified {
log.Debugf("Amp RequestSanitizer: sanitized request body")
}
return body
}

View File

@@ -1,6 +1,7 @@
package amp
import (
"strings"
"testing"
)
@@ -100,6 +101,50 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
}
}
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
rw := &ResponseRewriter{suppressedContentBlock: make(map[int]struct{})}
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
result := rw.rewriteStreamChunk(chunk)
// Streaming mode preserves thinking blocks (does NOT suppress them)
// to avoid breaking SSE index alignment and TUI rendering
if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result))
}
if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
t.Fatalf("expected thinking_delta to be preserved, got %s", string(result))
}
if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result))
}
if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
}
// Signature should be injected into both thinking and tool_use blocks
if count := strings.Count(string(result), `"signature":""`); count != 2 {
t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result))
}
}
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
result := SanitizeAmpRequestBody(input)
if contains(result, []byte("drop-whitespace")) {
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
}
if contains(result, []byte("drop-number")) {
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
}
if !contains(result, []byte("keep-valid")) {
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
}
if !contains(result, []byte("keep-text")) {
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {

View File

@@ -211,6 +211,10 @@ type QuotaExceeded struct {
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
}
// RoutingConfig configures how credentials are selected for requests.
@@ -257,8 +261,8 @@ type AmpCode struct {
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
@@ -380,6 +384,11 @@ type ClaudeKey struct {
// Cloak configures request cloaking for non-Claude-Code clients.
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
// Claude /v1/messages requests. It is disabled by default so upstream seed
// changes do not alter the proxy's legacy behavior.
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
}
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }

View File

@@ -477,6 +477,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.4-mini",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.4 mini",
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
Object: "model",
@@ -571,6 +584,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-3-pro-preview",
@@ -582,6 +596,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-3.1-pro-preview",
@@ -591,8 +606,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
ContextLength: 173000,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-3-flash-preview",
@@ -602,8 +618,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
Type: "github-copilot",
DisplayName: "Gemini 3 Flash (Preview)",
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
ContextLength: 1048576,
ContextLength: 173000,
MaxCompletionTokens: 65536,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "grok-code-fast-1",

View File

@@ -0,0 +1,29 @@
package registry
import "testing"
func TestGitHubCopilotGeminiModelsAreChatOnly(t *testing.T) {
models := GetGitHubCopilotModels()
required := map[string]bool{
"gemini-2.5-pro": false,
"gemini-3-pro-preview": false,
"gemini-3.1-pro-preview": false,
"gemini-3-flash-preview": false,
}
for _, model := range models {
if _, ok := required[model.ID]; !ok {
continue
}
required[model.ID] = true
if len(model.SupportedEndpoints) != 1 || model.SupportedEndpoints[0] != "/chat/completions" {
t.Fatalf("model %q supported endpoints = %v, want [/chat/completions]", model.ID, model.SupportedEndpoints)
}
}
for modelID, found := range required {
if !found {
t.Fatalf("expected GitHub Copilot model %q in definitions", modelID)
}
}
}

View File

@@ -14,6 +14,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -115,8 +116,8 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false)
if err != nil {
@@ -137,7 +138,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
@@ -151,17 +152,17 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
if len(wsResp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
helps.AppendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
}
if wsResp.Status < 200 || wsResp.Status >= 300 {
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
}
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
reporter.Publish(ctx, helps.ParseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
@@ -174,8 +175,8 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true)
if err != nil {
@@ -195,7 +196,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
@@ -208,24 +209,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
})
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
firstEvent, ok := <-wsStream
if !ok {
err = fmt.Errorf("wsrelay: stream closed before start")
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
metadataLogged := false
if firstEvent.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
metadataLogged = true
}
var body bytes.Buffer
if len(firstEvent.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
body.Write(firstEvent.Payload)
}
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
@@ -233,18 +234,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
for event := range wsStream {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
if body.Len() == 0 {
body.WriteString(event.Err.Error())
}
break
}
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
body.Write(event.Payload)
}
if event.Type == wsrelay.MessageTypeStreamEnd {
@@ -260,23 +261,23 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
metadataLogged := false
processEvent := func(event wsrelay.StreamEvent) bool {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
switch event.Type {
case wsrelay.MessageTypeStreamStart:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
filtered := FilterSSEUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
filtered := helps.FilterSSEUsageMetadata(event.Payload)
if detail, ok := helps.ParseGeminiStreamUsage(filtered); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
@@ -288,21 +289,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return false
case wsrelay.MessageTypeHTTPResp:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
return false
case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
@@ -345,7 +346,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
@@ -358,12 +359,12 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
})
resp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
if len(resp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
helps.AppendAPIResponseChunk(ctx, e.cfg, resp.Body)
}
if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
@@ -404,8 +405,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
return nil, translatedPayload{}, err
}
payload = fixGeminiImageAspectRatio(baseModel, payload)
requestedModel := payloadRequestedModel(opts, req.Model)
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -24,6 +24,7 @@ import (
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -47,12 +48,41 @@ const (
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
antigravityCreditsRetryTTL = 5 * time.Hour
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
)
type antigravity429Category string
const (
antigravity429Unknown antigravity429Category = "unknown"
antigravity429RateLimited antigravity429Category = "rate_limited"
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
antigravityCreditsExhaustedByAuth sync.Map
antigravityPreferCreditsByModel sync.Map
antigravityQuotaExhaustedKeywords = []string{
"quota_exhausted",
"quota exhausted",
}
antigravityCreditsExhaustedKeywords = []string{
"google_one_ai",
"insufficient credit",
"insufficient credits",
"not enough credit",
"not enough credits",
"credit exhausted",
"credits exhausted",
"credit balance",
"minimumcreditamountforusage",
"minimum credit amount for usage",
"minimum credit",
"resource has been exhausted",
}
)
// AntigravityExecutor proxies requests to the antigravity upstream.
@@ -113,7 +143,7 @@ func initAntigravityTransport() {
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
antigravityTransportOnce.Do(initAntigravityTransport)
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
client := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
// If no transport is set, use the shared HTTP/1.1 transport.
if client.Transport == nil {
client.Transport = antigravityTransport
@@ -183,6 +213,231 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
return httpClient.Do(httpReq)
}
func injectEnabledCreditTypes(payload []byte) []byte {
if len(payload) == 0 {
return nil
}
if !gjson.ValidBytes(payload) {
return nil
}
updated, err := sjson.SetRawBytes(payload, "enabledCreditTypes", []byte(`["GOOGLE_ONE_AI"]`))
if err != nil {
return nil
}
return updated
}
func classifyAntigravity429(body []byte) antigravity429Category {
if len(body) == 0 {
return antigravity429Unknown
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityQuotaExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return antigravity429QuotaExhausted
}
}
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
return antigravity429Unknown
}
details := gjson.GetBytes(body, "error.details")
if !details.Exists() || !details.IsArray() {
return antigravity429Unknown
}
for _, detail := range details.Array() {
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
continue
}
reason := strings.TrimSpace(detail.Get("reason").String())
if strings.EqualFold(reason, "QUOTA_EXHAUSTED") {
return antigravity429QuotaExhausted
}
if strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED") {
return antigravity429RateLimited
}
}
return antigravity429Unknown
}
func antigravityCreditsRetryEnabled(cfg *config.Config) bool {
return cfg != nil && cfg.QuotaExceeded.AntigravityCredits
}
func antigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) bool {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return false
}
value, ok := antigravityCreditsExhaustedByAuth.Load(auth.ID)
if !ok {
return false
}
until, ok := value.(time.Time)
if !ok || until.IsZero() {
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
return false
}
if !until.After(now) {
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
return false
}
return true
}
func markAntigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
}
antigravityCreditsExhaustedByAuth.Store(auth.ID, now.Add(antigravityCreditsRetryTTL))
}
func clearAntigravityCreditsExhausted(auth *cliproxyauth.Auth) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
}
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
}
func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string {
if auth == nil {
return ""
}
authID := strings.TrimSpace(auth.ID)
modelName = strings.TrimSpace(modelName)
if authID == "" || modelName == "" {
return ""
}
return authID + "|" + modelName
}
func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return false
}
value, ok := antigravityPreferCreditsByModel.Load(key)
if !ok {
return false
}
until, ok := value.(time.Time)
if !ok || until.IsZero() {
antigravityPreferCreditsByModel.Delete(key)
return false
}
if !until.After(now) {
antigravityPreferCreditsByModel.Delete(key)
return false
}
return true
}
func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return
}
until := now.Add(antigravityCreditsRetryTTL)
if retryAfter != nil && *retryAfter > 0 {
until = now.Add(*retryAfter)
}
antigravityPreferCreditsByModel.Store(key, until)
}
func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) {
key := antigravityPreferCreditsKey(auth, modelName)
if key == "" {
return
}
antigravityPreferCreditsByModel.Delete(key)
}
func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool {
if reqErr != nil || statusCode == 0 {
return false
}
if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout {
return false
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityCreditsExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return true
}
}
return false
}
func newAntigravityStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if statusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil {
err.retryAfter = retryAfter
}
}
return err
}
func (e *AntigravityExecutor) attemptCreditsFallback(
ctx context.Context,
auth *cliproxyauth.Auth,
httpClient *http.Client,
token string,
modelName string,
payload []byte,
stream bool,
alt string,
baseURL string,
originalBody []byte,
) (*http.Response, bool) {
if !antigravityCreditsRetryEnabled(e.cfg) {
return nil, false
}
if classifyAntigravity429(originalBody) != antigravity429QuotaExhausted {
return nil, false
}
now := time.Now()
if antigravityCreditsExhausted(auth, now) {
return nil, false
}
creditsPayload := injectEnabledCreditTypes(payload)
if len(creditsPayload) == 0 {
return nil, false
}
httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL)
if errReq != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errReq)
return nil, true
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, true
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
retryAfter, _ := parseRetryDelay(originalBody)
markAntigravityPreferCredits(auth, modelName, now, retryAfter)
clearAntigravityCreditsExhausted(auth)
return httpResp, true
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, true
}
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, modelName)
markAntigravityCreditsExhausted(auth, now)
}
return nil, true
}
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
@@ -203,8 +458,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
@@ -222,8 +477,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -237,7 +492,15 @@ attemptLoop:
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, false, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
@@ -245,7 +508,7 @@ attemptLoop:
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
@@ -260,20 +523,50 @@ attemptLoop:
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone())
creditsBody, errCreditsRead := io.ReadAll(creditsResp.Body)
if errClose := creditsResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close credits success response body error: %v", errClose)
}
if errCreditsRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead)
err = errCreditsRead
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody)
reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()}
reporter.EnsurePublished(ctx)
return resp, nil
}
}
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
@@ -295,33 +588,21 @@ attemptLoop:
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
return resp, err
}
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(lastStatus, lastBody)
case lastErr != nil:
err = lastErr
default:
@@ -345,8 +626,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
@@ -364,8 +645,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -379,7 +660,15 @@ attemptLoop:
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
@@ -387,7 +676,7 @@ attemptLoop:
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
@@ -401,14 +690,14 @@ attemptLoop:
err = errDo
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
@@ -427,7 +716,24 @@ attemptLoop:
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
httpResp = creditsResp
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
}
}
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
goto streamSuccessClaudeNonStream
}
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
@@ -449,16 +755,11 @@ attemptLoop:
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
return resp, err
}
streamSuccessClaudeNonStream:
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
@@ -471,29 +772,29 @@ attemptLoop:
scanner.Buffer(nil, streamScannerBuffer)
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
line = helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
payload := helps.JSONPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
reporter.Publish(ctx, detail)
}
out <- cliproxyexecutor.StreamChunk{Payload: payload}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}
}(httpResp)
@@ -509,24 +810,18 @@ attemptLoop:
}
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
reporter.Publish(ctx, helps.ParseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(lastStatus, lastBody)
case lastErr != nil:
err = lastErr
default:
@@ -748,8 +1043,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
@@ -767,8 +1062,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
@@ -782,14 +1077,22 @@ attemptLoop:
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
requestPayload := translated
usedCreditsDirect := false
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
requestPayload = creditsPayload
usedCreditsDirect = true
}
}
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return nil, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil, errDo
}
@@ -803,14 +1106,14 @@ attemptLoop:
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
@@ -829,7 +1132,24 @@ attemptLoop:
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode == http.StatusTooManyRequests {
if usedCreditsDirect {
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
clearAntigravityPreferCredits(auth, baseModel)
markAntigravityCreditsExhausted(auth, time.Now())
}
} else {
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
if creditsResp != nil {
httpResp = creditsResp
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
}
}
}
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
goto streamSuccessExecuteStream
}
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
@@ -851,16 +1171,11 @@ attemptLoop:
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
return nil, err
}
streamSuccessExecuteStream:
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
@@ -874,19 +1189,19 @@ attemptLoop:
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
line = helps.FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
payload := helps.JSONPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
@@ -899,11 +1214,11 @@ attemptLoop:
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}
}(httpResp)
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -911,13 +1226,7 @@ attemptLoop:
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
err = newAntigravityStatusErr(lastStatus, lastBody)
case lastErr != nil:
err = lastErr
default:
@@ -1012,7 +1321,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
httpReq.Host = host
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: requestURL.String(),
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -1026,7 +1335,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return cliproxyexecutor.Response{}, errDo
}
@@ -1040,16 +1349,16 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
return cliproxyexecutor.Response{}, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
@@ -1316,7 +1625,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
if e.cfg != nil && e.cfg.RequestLog {
payloadLog = []byte(payloadStr)
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: requestURL.String(),
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -1479,7 +1788,7 @@ func antigravityWait(ctx context.Context, wait time.Duration) error {
}
}
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
var antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
return []string{base}
}

View File

@@ -0,0 +1,423 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
func resetAntigravityCreditsRetryState() {
antigravityCreditsExhaustedByAuth = sync.Map{}
antigravityPreferCreditsByModel = sync.Map{}
}
func TestClassifyAntigravity429(t *testing.T) {
t.Run("quota exhausted", func(t *testing.T) {
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
}
})
t.Run("structured rate limit", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
if got := classifyAntigravity429(body); got != antigravity429RateLimited {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited)
}
})
t.Run("structured quota exhausted", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "QUOTA_EXHAUSTED"}
]
}
}`)
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
}
})
t.Run("unknown", func(t *testing.T) {
body := []byte(`{"error":{"message":"too many requests"}}`)
if got := classifyAntigravity429(body); got != antigravity429Unknown {
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429Unknown)
}
})
}
func TestInjectEnabledCreditTypes(t *testing.T) {
body := []byte(`{"model":"gemini-2.5-flash","request":{}}`)
got := injectEnabledCreditTypes(body)
if got == nil {
t.Fatal("injectEnabledCreditTypes() returned nil")
}
if !strings.Contains(string(got), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("injectEnabledCreditTypes() = %s, want enabledCreditTypes", string(got))
}
if got := injectEnabledCreditTypes([]byte(`not json`)); got != nil {
t.Fatalf("injectEnabledCreditTypes() for invalid json = %s, want nil", string(got))
}
}
func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) {
for _, body := range [][]byte{
[]byte(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`),
[]byte(`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`),
[]byte(`{"error":{"message":"Resource has been exhausted"}}`),
} {
if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, body, nil) {
t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body))
}
}
if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) {
t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false")
}
}
func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var (
mu sync.Mutex
requestBodies []string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
requestBodies = append(requestBodies, string(body))
reqNum := len(requestBodies)
mu.Unlock()
if reqNum == 1 {
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
return
}
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("second request body missing enabledCreditTypes: %s", string(body))
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-credits-ok",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(resp.Payload) == 0 {
t.Fatal("Execute() returned empty payload")
}
mu.Lock()
defer mu.Unlock()
if len(requestBodies) != 2 {
t.Fatalf("request count = %d, want 2", len(requestBodies))
}
}
func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var requestCount int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-credits-exhausted",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
markAntigravityCreditsExhausted(auth, time.Now())
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err == nil {
t.Fatal("Execute() error = nil, want 429")
}
sErr, ok := err.(statusErr)
if !ok {
t.Fatalf("Execute() error type = %T, want statusErr", err)
}
if got := sErr.StatusCode(); got != http.StatusTooManyRequests {
t.Fatalf("Execute() status code = %d, want %d", got, http.StatusTooManyRequests)
}
if requestCount != 1 {
t.Fatalf("request count = %d, want 1", requestCount)
}
}
func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var (
mu sync.Mutex
requestBodies []string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
requestBodies = append(requestBodies, string(body))
reqNum := len(requestBodies)
mu.Unlock()
switch reqNum {
case 1:
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`))
case 2, 3:
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("request %d body missing enabledCreditTypes: %s", reqNum, string(body))
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
default:
t.Fatalf("unexpected request count %d", reqNum)
}
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-prefer-credits",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
request := cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity}
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
t.Fatalf("first Execute() error = %v", err)
}
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
t.Fatalf("second Execute() error = %v", err)
}
mu.Lock()
defer mu.Unlock()
if len(requestBodies) != 3 {
t.Fatalf("request count = %d, want 3", len(requestBodies))
}
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0])
}
if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("fallback request missing credits: %s", requestBodies[1])
}
if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("preferred request missing credits: %s", requestBodies[2])
}
}
func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var (
mu sync.Mutex
firstCount int
secondCount int
)
firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
firstCount++
reqNum := firstCount
mu.Unlock()
switch reqNum {
case 1:
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`))
case 2:
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("credits retry missing enabledCreditTypes: %s", string(body))
}
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`))
default:
t.Fatalf("unexpected first server request count %d", reqNum)
}
}))
defer firstServer.Close()
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
secondCount++
mu.Unlock()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
}))
defer secondServer.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
})
auth := &cliproxyauth.Auth{
ID: "auth-baseurl-fallback",
Attributes: map[string]string{
"base_url": firstServer.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
originalOrder := antigravityBaseURLFallbackOrder
defer func() { antigravityBaseURLFallbackOrder = originalOrder }()
antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
return []string{firstServer.URL, secondServer.URL}
}
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(resp.Payload) == 0 {
t.Fatal("Execute() returned empty payload")
}
if firstCount != 2 {
t.Fatalf("first server request count = %d, want 2", firstCount)
}
if secondCount != 1 {
t.Fatalf("second server request count = %d, want 1", secondCount)
}
}
func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) {
resetAntigravityCreditsRetryState()
t.Cleanup(resetAntigravityCreditsRetryState)
var requestBodies []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
requestBodies = append(requestBodies, string(body))
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
}))
defer server.Close()
exec := NewAntigravityExecutor(&config.Config{
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false},
})
auth := &cliproxyauth.Auth{
ID: "auth-flag-disabled",
Attributes: map[string]string{
"base_url": server.URL,
},
Metadata: map[string]any{
"access_token": "token",
"project_id": "project-1",
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
},
}
markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil)
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gemini-2.5-flash",
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatAntigravity,
})
if err == nil {
t.Fatal("Execute() error = nil, want 429")
}
if len(requestBodies) != 1 {
t.Fatalf("request count = %d, want 1", len(requestBodies))
}
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0])
}
}

View File

@@ -22,6 +22,8 @@ import (
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -44,6 +46,10 @@ type ClaudeExecutor struct {
// Previously "proxy_" was used but this is a detectable fingerprint difference.
const claudeToolPrefix = ""
// Anthropic-compatible upstreams may reject or even crash when Claude models
// omit max_tokens. Prefer registered model metadata before using a fallback.
const defaultModelMaxTokens = 1024
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
func (e *ClaudeExecutor) Identifier() string { return "claude" }
@@ -86,7 +92,7 @@ func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -101,8 +107,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
@@ -125,8 +131,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body = ensureModelMaxTokens(body, baseModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
@@ -153,6 +160,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
if experimentalCCHSigningEnabled(e.cfg, auth) {
bodyForUpstream = signAnthropicMessagesBody(bodyForUpstream)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
@@ -166,7 +176,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -178,33 +188,33 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
@@ -213,7 +223,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -226,19 +236,19 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}()
data, err := io.ReadAll(decodedBody)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if stream {
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
}
} else {
reporter.publish(ctx, parseClaudeUsage(data))
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
@@ -269,8 +279,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
originalPayloadSource := req.Payload
@@ -291,8 +301,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body = ensureModelMaxTokens(body, baseModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
@@ -316,6 +327,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
if experimentalCCHSigningEnabled(e.cfg, auth) {
bodyForUpstream = signAnthropicMessagesBody(bodyForUpstream)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
@@ -329,7 +343,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -341,33 +355,33 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -376,7 +390,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -397,9 +411,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
scanner.Buffer(nil, 52_428_800) // 50MB
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
@@ -411,8 +425,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
return
@@ -424,9 +438,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
@@ -446,8 +460,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -496,7 +510,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -508,32 +522,32 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
resp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
helps.LogWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -541,7 +555,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}
decodedBody, err := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
@@ -554,10 +568,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}()
data, err := io.ReadAll(decodedBody)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
@@ -793,10 +807,10 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header
}
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
var deviceProfile claudeDeviceProfile
stabilizeDeviceProfile := helps.ClaudeDeviceProfileStabilizationEnabled(cfg)
var deviceProfile helps.ClaudeDeviceProfile
if stabilizeDeviceProfile {
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
deviceProfile = helps.ResolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
}
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
@@ -864,9 +878,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
}
util.ApplyCustomHeadersFromAttrs(r, attrs)
if stabilizeDeviceProfile {
applyClaudeDeviceProfileHeaders(r, deviceProfile)
helps.ApplyClaudeDeviceProfileHeaders(r, deviceProfile)
} else {
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
helps.ApplyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
}
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
// may override it with a user-configured value. Compressed SSE breaks the line
@@ -893,7 +907,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
func checkSystemInstructions(payload []byte) []byte {
return checkSystemInstructionsWithMode(payload, false)
return checkSystemInstructionsWithSigningMode(payload, false, false)
}
func isClaudeOAuthToken(apiKey string) bool {
@@ -1037,7 +1051,7 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
if prefix == "" {
return line
}
payload := jsonPayload(line)
payload := helps.JSONPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
}
@@ -1115,43 +1129,14 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bo
return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
if cfg == nil || auth == nil {
return nil
}
apiKey, baseURL := claudeCreds(auth)
if apiKey == "" {
return nil
}
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
// Match by API key
if strings.EqualFold(cfgKey, apiKey) {
// If baseURL is specified, also check it
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
continue
}
return entry.Cloak
}
}
return nil
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
return helps.CachedUserID(apiKey)
}
return generateFakeUserID()
return helps.GenerateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
@@ -1161,7 +1146,7 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
if existingUserID == "" || !helps.IsValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
@@ -1170,29 +1155,36 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
// generateBillingHeader creates the x-anthropic-billing-header text block that
// real Claude Code prepends to every system prompt array.
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
func generateBillingHeader(payload []byte) string {
// Generate a deterministic cch hash from the payload content (system + messages + tools).
// Real Claude Code uses a 5-char hex hash that varies per request.
h := sha256.Sum256(payload)
cch := hex.EncodeToString(h[:])[:5]
func generateBillingHeader(payload []byte, experimentalCCHSigning bool) string {
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
buildBytes := make([]byte, 2)
_, _ = rand.Read(buildBytes)
buildHash := hex.EncodeToString(buildBytes)[:3]
if experimentalCCHSigning {
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=00000;", buildHash)
}
// Generate a deterministic cch hash from the payload content (system + messages + tools).
// Real Claude Code uses a 5-char hex hash that varies per request.
h := sha256.Sum256(payload)
cch := hex.EncodeToString(h[:])[:5]
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
}
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
return checkSystemInstructionsWithSigningMode(payload, strictMode, false)
}
// checkSystemInstructionsWithSigningMode injects Claude Code-style system blocks:
//
// system[0]: billing header (no cache_control)
// system[1]: agent identifier (no cache_control)
// system[2..]: user system messages (cache_control added when missing)
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
func checkSystemInstructionsWithSigningMode(payload []byte, strictMode bool, experimentalCCHSigning bool) []byte {
system := gjson.GetBytes(payload, "system")
billingText := generateBillingHeader(payload)
billingText := generateBillingHeader(payload, experimentalCCHSigning)
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
// No cache_control on the agent block. It is a cloaking artifact with zero cache
// value (the last system block is what actually triggers caching of all system content).
@@ -1247,51 +1239,41 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
clientUserAgent := getClientUserAgent(ctx)
useExperimentalCCHSigning := experimentalCCHSigningEnabled(cfg, auth)
// Get cloak config from ClaudeKey configuration
cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
// Determine cloak settings
var cloakMode string
var strictMode bool
var sensitiveWords []string
var cacheUserID bool
cloakMode := attrMode
strictMode := attrStrict
sensitiveWords := attrWords
cacheUserID := attrCache
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
if mode := strings.TrimSpace(cloakCfg.Mode); mode != "" {
cloakMode = mode
}
if cloakCfg.StrictMode {
strictMode = true
}
if len(cloakCfg.SensitiveWords) > 0 {
sensitiveWords = cloakCfg.SensitiveWords
}
if cloakCfg.CacheUserID != nil {
cacheUserID = *cloakCfg.CacheUserID
}
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
}
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
cacheUserID = attrCache
}
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
_, _, _, attrCache := getCloakConfigFromAuth(auth)
cacheUserID = attrCache
}
// Determine if cloaking should be applied
if !shouldCloak(cloakMode, clientUserAgent) {
if !helps.ShouldCloak(cloakMode, clientUserAgent) {
return payload
}
// Skip system instructions for claude-3-5-haiku models
if !strings.HasPrefix(model, "claude-3-5-haiku") {
payload = checkSystemInstructionsWithMode(payload, strictMode)
payload = checkSystemInstructionsWithSigningMode(payload, strictMode, useExperimentalCCHSigning)
}
// Inject fake user ID
@@ -1299,8 +1281,8 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {
matcher := buildSensitiveWordMatcher(sensitiveWords)
payload = obfuscateSensitiveWords(payload, matcher)
matcher := helps.BuildSensitiveWordMatcher(sensitiveWords)
payload = helps.ObfuscateSensitiveWords(payload, matcher)
}
return payload
@@ -1310,7 +1292,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
// According to Anthropic's documentation, cache prefixes are created in order: tools -> system -> messages.
// This function adds cache_control to:
// 1. The LAST tool in the tools array (caches all tool definitions)
// 2. The LAST element in the system array (caches system prompt)
// 2. The LAST system prompt element
// 3. The SECOND-TO-LAST user turn (caches conversation history for multi-turn)
//
// Up to 4 cache breakpoints are allowed per request. Tools, System, and Messages are INDEPENDENT breakpoints.
@@ -1880,3 +1862,26 @@ func injectSystemCacheControl(payload []byte) []byte {
return payload
}
func ensureModelMaxTokens(body []byte, modelID string) []byte {
if len(body) == 0 || !gjson.ValidBytes(body) {
return body
}
if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() {
return body
}
for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) {
if strings.EqualFold(provider, "claude") {
maxTokens := defaultModelMaxTokens
if info := registry.GetGlobalRegistry().GetModelInfo(strings.TrimSpace(modelID), "claude"); info != nil && info.MaxCompletionTokens > 0 {
maxTokens = info.MaxCompletionTokens
}
body, _ = sjson.SetBytes(body, "max_tokens", maxTokens)
return body
}
}
return body
}

View File

@@ -4,9 +4,11 @@ import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"sync"
"testing"
@@ -14,7 +16,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
xxHash64 "github.com/pierrec/xxHash/xxHash64"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -23,9 +28,7 @@ import (
)
func resetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
helps.ResetClaudeDeviceProfileCache()
}
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
@@ -338,7 +341,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
var pauseOnce sync.Once
var releaseOnce sync.Once
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
helps.ClaudeDeviceProfileBeforeCandidateStore = func(candidate helps.ClaudeDeviceProfile) {
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
return
}
@@ -346,13 +349,13 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
<-releaseLow
}
t.Cleanup(func() {
claudeDeviceProfileBeforeCandidateStore = nil
helps.ClaudeDeviceProfileBeforeCandidateStore = nil
releaseOnce.Do(func() { close(releaseLow) })
})
lowResultCh := make(chan claudeDeviceProfile, 1)
lowResultCh := make(chan helps.ClaudeDeviceProfile, 1)
go func() {
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
lowResultCh <- helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
@@ -367,7 +370,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
t.Fatal("timed out waiting for lower candidate to pause before storing")
}
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
highResult := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
@@ -398,7 +401,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
}
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
cached := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"curl/8.7.1"},
}, cfg)
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
@@ -564,7 +567,7 @@ func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *tes
})
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
}
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
@@ -591,14 +594,14 @@ func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallbac
})
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
}
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
if claudeDeviceProfileStabilizationEnabled(nil) {
if helps.ClaudeDeviceProfileStabilizationEnabled(nil) {
t.Fatal("expected nil config to default to disabled stabilization")
}
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
if helps.ClaudeDeviceProfileStabilizationEnabled(&config.Config{}) {
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
}
}
@@ -796,8 +799,6 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
}
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
resetUserIDCache()
var userIDs []string
var requestModels []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -857,15 +858,13 @@ func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
if userIDs[0] != userIDs[1] {
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
}
if !isValidUserID(userIDs[0]) {
if !helps.IsValidUserID(userIDs[0]) {
t.Fatalf("user_id %q is not valid", userIDs[0])
}
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
}
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
resetUserIDCache()
var userIDs []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
@@ -903,7 +902,7 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
if userIDs[0] == userIDs[1] {
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
}
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
if !helps.IsValidUserID(userIDs[0]) || !helps.IsValidUserID(userIDs[1]) {
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
}
}
@@ -1183,6 +1182,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
}
}
func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-max-completion-tokens-client"
modelID := "test-claude-max-completion-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
MaxCompletionTokens: 4096,
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
t.Fatalf("max_tokens = %d, want %d", got, 4096)
}
}
func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-default-max-tokens-client"
modelID := "test-claude-default-max-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
}
}
func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-claude-preserve-max-tokens-client"
modelID := "test-claude-preserve-max-tokens-model"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
ID: modelID,
Type: "claude",
OwnedBy: "anthropic",
Object: "model",
Created: time.Now().Unix(),
MaxCompletionTokens: 4096,
UserDefined: true,
}})
defer reg.UnregisterClient(clientID)
input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, modelID)
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
t.Fatalf("max_tokens = %d, want %d", got, 2048)
}
}
func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
out := ensureModelMaxTokens(input, "test-claude-unregistered-model")
if gjson.GetBytes(out, "max_tokens").Exists() {
t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
}
}
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
// requests use Accept-Encoding: identity so the upstream cannot respond with a
// compressed SSE body that would silently break the line scanner.
@@ -1340,6 +1416,35 @@ func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes even when Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
// plain text untouched when Content-Encoding is absent and no magic bytes match.
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
@@ -1411,77 +1516,6 @@ func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T)
}
}
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
// path's enforced identity encoding.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
// Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
// the Content-Encoding header. This closes the gap left by PR #1771, which only
@@ -1565,6 +1599,45 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
}
}
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies that the
// streaming executor enforces Accept-Encoding: identity regardless of auth.Attributes override.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
@@ -1648,3 +1721,115 @@ func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}
func TestClaudeExecutor_ExperimentalCCHSigningDisabledByDefaultKeepsLegacyHeader(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected request body to be captured")
}
billingHeader := gjson.GetBytes(seenBody, "system.0.text").String()
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
t.Fatalf("system.0.text = %q, want billing header", billingHeader)
}
if strings.Contains(billingHeader, "cch=00000;") {
t.Fatalf("legacy mode should not forward cch placeholder, got %q", billingHeader)
}
}
func TestClaudeExecutor_ExperimentalCCHSigningOptInSignsFinalBody(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{
ClaudeKey: []config.ClaudeKey{{
APIKey: "key-123",
BaseURL: server.URL,
ExperimentalCCHSigning: true,
}},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
const messageText = "please keep literal cch=00000 in this message"
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"please keep literal cch=00000 in this message"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected request body to be captured")
}
if got := gjson.GetBytes(seenBody, "messages.0.content.0.text").String(); got != messageText {
t.Fatalf("message text = %q, want %q", got, messageText)
}
billingPattern := regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)([0-9a-f]{5})(;)`)
match := billingPattern.FindSubmatch(seenBody)
if match == nil {
t.Fatalf("expected signed billing header in body: %s", string(seenBody))
}
actualCCH := string(match[2])
unsignedBody := billingPattern.ReplaceAll(seenBody, []byte(`${1}00000${3}`))
wantCCH := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, 0x6E52736AC806831E)&0xFFFFF)
if actualCCH != wantCCH {
t.Fatalf("cch = %q, want %q\nbody: %s", actualCCH, wantCCH, string(seenBody))
}
}
func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmitted(t *testing.T) {
cfg := &config.Config{
ClaudeKey: []config.ClaudeKey{{
APIKey: "key-123",
Cloak: &config.CloakConfig{
StrictMode: true,
SensitiveWords: []string{"proxy"},
},
}},
}
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "key-123"}}
payload := []byte(`{"system":"proxy rules","messages":[{"role":"user","content":[{"type":"text","text":"proxy access"}]}]}`)
out := applyCloaking(context.Background(), cfg, auth, payload, "claude-3-5-sonnet-20241022", "key-123")
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("expected strict mode to keep only injected system blocks, got %d", len(blocks))
}
if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); !strings.Contains(got, "\u200B") {
t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got)
}
}

View File

@@ -0,0 +1,81 @@
package executor
import (
"fmt"
"regexp"
"strings"
xxHash64 "github.com/pierrec/xxHash/xxHash64"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const claudeCCHSeed uint64 = 0x6E52736AC806831E
var claudeBillingHeaderCCHPattern = regexp.MustCompile(`\bcch=([0-9a-f]{5});`)
func signAnthropicMessagesBody(body []byte) []byte {
billingHeader := gjson.GetBytes(body, "system.0.text").String()
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
return body
}
if !claudeBillingHeaderCCHPattern.MatchString(billingHeader) {
return body
}
unsignedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(billingHeader, "cch=00000;")
unsignedBody, err := sjson.SetBytes(body, "system.0.text", unsignedBillingHeader)
if err != nil {
return body
}
cch := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, claudeCCHSeed)&0xFFFFF)
signedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(unsignedBillingHeader, "cch="+cch+";")
signedBody, err := sjson.SetBytes(unsignedBody, "system.0.text", signedBillingHeader)
if err != nil {
return unsignedBody
}
return signedBody
}
func resolveClaudeKeyConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.ClaudeKey {
if cfg == nil || auth == nil {
return nil
}
apiKey, baseURL := claudeCreds(auth)
if apiKey == "" {
return nil
}
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if !strings.EqualFold(cfgKey, apiKey) {
continue
}
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
continue
}
return entry
}
return nil
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
entry := resolveClaudeKeyConfig(cfg, auth)
if entry == nil {
return nil
}
return entry.Cloak
}
func experimentalCCHSigningEnabled(cfg *config.Config, auth *cliproxyauth.Auth) bool {
entry := resolveClaudeKeyConfig(cfg, auth)
return entry != nil && entry.ExperimentalCCHSigning
}

View File

@@ -1,125 +0,0 @@
package executor
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
type codexContinuity struct {
Key string
Source string
}
func metadataString(meta map[string]any, key string) string {
if len(meta) == 0 {
return ""
}
raw, ok := meta[key]
if !ok || raw == nil {
return ""
}
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
func principalString(raw any) string {
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case fmt.Stringer:
return strings.TrimSpace(v.String())
default:
return strings.TrimSpace(fmt.Sprintf("%v", raw))
}
}
func resolveCodexContinuity(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) codexContinuity {
if promptCacheKey := strings.TrimSpace(gjson.GetBytes(req.Payload, "prompt_cache_key").String()); promptCacheKey != "" {
return codexContinuity{Key: promptCacheKey, Source: "prompt_cache_key"}
}
if executionSession := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); executionSession != "" {
return codexContinuity{Key: executionSession, Source: "execution_session"}
}
if ginCtx := ginContextFrom(ctx); ginCtx != nil {
if ginCtx.Request != nil {
if v := strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")); v != "" {
return codexContinuity{Key: v, Source: "idempotency_key"}
}
}
if v, exists := ginCtx.Get("apiKey"); exists && v != nil {
if trimmed := principalString(v); trimmed != "" {
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+trimmed)).String(), Source: "client_principal"}
}
}
}
if auth != nil {
if authID := strings.TrimSpace(auth.ID); authID != "" {
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:"+authID)).String(), Source: "auth_id"}
}
}
return codexContinuity{}
}
func applyCodexContinuityBody(rawJSON []byte, continuity codexContinuity) []byte {
if continuity.Key == "" {
return rawJSON
}
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", continuity.Key)
return rawJSON
}
func applyCodexContinuityHeaders(headers http.Header, continuity codexContinuity) {
if headers == nil || continuity.Key == "" {
return
}
headers.Set("session_id", continuity.Key)
}
func logCodexRequestDiagnostics(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, headers http.Header, body []byte, continuity codexContinuity) {
if !log.IsLevelEnabled(log.DebugLevel) {
return
}
entry := logWithRequestID(ctx)
authID := ""
authFile := ""
if auth != nil {
authID = strings.TrimSpace(auth.ID)
authFile = strings.TrimSpace(auth.FileName)
}
selectedAuthID := metadataString(opts.Metadata, cliproxyexecutor.SelectedAuthMetadataKey)
executionSessionID := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey)
entry.Debugf(
"codex request diagnostics auth_id=%s selected_auth_id=%s auth_file=%s exec_session=%s continuity_source=%s session_id=%s prompt_cache_key=%s prompt_cache_retention=%s store=%t has_instructions=%t reasoning_effort=%s reasoning_summary=%s chatgpt_account_id=%t originator=%s model=%s source_format=%s",
authID,
selectedAuthID,
authFile,
executionSessionID,
continuity.Source,
strings.TrimSpace(headers.Get("session_id")),
gjson.GetBytes(body, "prompt_cache_key").String(),
gjson.GetBytes(body, "prompt_cache_retention").String(),
gjson.GetBytes(body, "store").Bool(),
gjson.GetBytes(body, "instructions").Exists(),
gjson.GetBytes(body, "reasoning.effort").String(),
gjson.GetBytes(body, "reasoning.summary").String(),
strings.TrimSpace(headers.Get("Chatgpt-Account-Id")) != "",
strings.TrimSpace(headers.Get("Originator")),
req.Model,
opts.SourceFormat.String(),
)
}

View File

@@ -13,6 +13,7 @@ import (
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -28,8 +29,8 @@ import (
)
const (
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexOriginator = "codex_cli_rs"
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
codexOriginator = "codex-tui"
)
var dataTag = []byte("data:")
@@ -73,7 +74,7 @@ func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -88,8 +89,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -106,31 +107,29 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -141,10 +140,10 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -152,20 +151,20 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
@@ -178,8 +177,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
continue
}
if detail, ok := parseCodexUsage(line); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(line); ok {
reporter.Publish(ctx, detail)
}
var param any
@@ -199,8 +198,8 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai-response")
@@ -217,25 +216,25 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream")
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -246,10 +245,10 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -257,22 +256,22 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
reporter.ensurePublished(ctx)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
reporter.EnsurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -290,8 +289,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -308,30 +307,28 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
body = normalizeCodexInstructions(body)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
if err != nil {
return nil, err
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -343,24 +340,24 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, readErr := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
return nil, readErr
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err
}
@@ -377,13 +374,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if bytes.HasPrefix(line, dataTag) {
data := bytes.TrimSpace(line[5:])
if gjson.GetBytes(data, "type").String() == "response.completed" {
if detail, ok := parseCodexUsage(data); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(data); ok {
reporter.Publish(ctx, detail)
}
}
}
@@ -394,8 +391,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -420,9 +417,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "stream", false)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
body = normalizeCodexInstructions(body)
enc, err := tokenizerForCodexModel(baseModel)
if err != nil {
@@ -600,41 +595,43 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
return auth, nil
}
func (e *CodexExecutor) cacheHelper(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, url string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) (*http.Request, codexContinuity, error) {
var cache codexCache
continuity := codexContinuity{}
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
var cache helps.CodexCache
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
var ok bool
if cache, ok = getCodexCache(key); !ok {
cache = codexCache{
if cache, ok = helps.GetCodexCache(key); !ok {
cache = helps.CodexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
setCodexCache(key, cache)
helps.SetCodexCache(key, cache)
}
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
}
} else if from == "openai-response" {
promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key")
if promptCacheKey.Exists() {
cache.ID = promptCacheKey.String()
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
}
} else if from == "openai" {
continuity = resolveCodexContinuity(ctx, auth, req, opts)
cache.ID = continuity.Key
if apiKey := strings.TrimSpace(helps.APIKeyFromContext(ctx)); apiKey != "" {
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
}
}
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
if err != nil {
return nil, continuity, err
return nil, err
}
applyCodexContinuityHeaders(httpReq.Header, continuity)
return httpReq, continuity, nil
if cache.ID != "" {
httpReq.Header.Set("Session_id", cache.ID)
}
return httpReq, nil
}
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
@@ -646,13 +643,19 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
ginHeaders = ginCtx.Request.Header
}
if ginHeaders.Get("X-Codex-Beta-Features") != "" {
r.Header.Set("X-Codex-Beta-Features", ginHeaders.Get("X-Codex-Beta-Features"))
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
misc.EnsureHeader(r.Header, ginHeaders, "session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if strings.Contains(r.Header.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
}
if stream {
r.Header.Set("Accept", "text/event-stream")
} else {
@@ -697,6 +700,14 @@ func newCodexStatusErr(statusCode int, body []byte) statusErr {
return err
}
func normalizeCodexInstructions(body []byte) []byte {
instructions := gjson.GetBytes(body, "instructions")
if !instructions.Exists() || instructions.Type == gjson.Null {
body, _ = sjson.SetBytes(body, "instructions", "")
}
return body
}
func isCodexModelCapacityError(errorBody []byte) bool {
if len(errorBody) == 0 {
return false

View File

@@ -8,7 +8,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
@@ -28,7 +27,7 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
}
url := "https://example.com/responses"
httpReq, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
@@ -43,14 +42,14 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
if gotKey != expectedKey {
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
}
if gotSession := httpReq.Header.Get("session_id"); gotSession != expectedKey {
t.Fatalf("session_id = %q, want %q", gotSession, expectedKey)
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != "" {
t.Fatalf("Conversation_id = %q, want empty", gotConversation)
}
if got := httpReq.Header.Get("Conversation_id"); got != "" {
t.Fatalf("Conversation_id = %q, want empty", got)
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
}
httpReq2, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error (second call): %v", err)
}
@@ -63,118 +62,3 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
}
}
func TestCodexExecutorCacheHelper_OpenAIResponses_PreservesPromptCacheRetention(t *testing.T) {
executor := &CodexExecutor{}
url := "https://example.com/responses"
req := cliproxyexecutor.Request{
Model: "gpt-5.3-codex",
Payload: []byte(`{"model":"gpt-5.3-codex","prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
}
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true,"prompt_cache_retention":"persistent"}`)
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai-response"), url, req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "cache-key-1" {
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
}
if got := gjson.GetBytes(body, "prompt_cache_retention").String(); got != "persistent" {
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
}
if got := httpReq.Header.Get("session_id"); got != "cache-key-1" {
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
}
if got := httpReq.Header.Get("Conversation_id"); got != "" {
t.Fatalf("Conversation_id = %q, want empty", got)
}
}
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_UsesExecutionSessionForContinuity(t *testing.T) {
executor := &CodexExecutor{}
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
req := cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4"}`),
}
opts := cliproxyexecutor.Options{Metadata: map[string]any{cliproxyexecutor.ExecutionSessionMetadataKey: "exec-session-1"}}
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai"), "https://example.com/responses", req, opts, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "exec-session-1" {
t.Fatalf("prompt_cache_key = %q, want %q", got, "exec-session-1")
}
if got := httpReq.Header.Get("session_id"); got != "exec-session-1" {
t.Fatalf("session_id = %q, want %q", got, "exec-session-1")
}
}
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_FallsBackToStableAuthID(t *testing.T) {
executor := &CodexExecutor{}
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
req := cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4"}`),
}
auth := &cliproxyauth.Auth{ID: "codex-auth-1", Provider: "codex"}
httpReq, _, err := executor.cacheHelper(context.Background(), auth, sdktranslator.FromString("openai"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
expected := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:codex-auth-1")).String()
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != expected {
t.Fatalf("prompt_cache_key = %q, want %q", got, expected)
}
if got := httpReq.Header.Get("session_id"); got != expected {
t.Fatalf("session_id = %q, want %q", got, expected)
}
}
func TestCodexExecutorCacheHelper_ClaudePreservesCacheContinuity(t *testing.T) {
executor := &CodexExecutor{}
req := cliproxyexecutor.Request{
Model: "claude-3-7-sonnet",
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
}
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
httpReq, continuity, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("claude"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
if continuity.Key == "" {
t.Fatal("continuity.Key = empty, want non-empty")
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != continuity.Key {
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
}
if got := httpReq.Header.Get("session_id"); got != continuity.Key {
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
}
}

View File

@@ -0,0 +1,79 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestCodexExecutorCompactAddsDefaultInstructions(t *testing.T) {
cases := []struct {
name string
payload string
}{
{
name: "missing instructions",
payload: `{"model":"gpt-5.4","input":"hello"}`,
},
{
name: "null instructions",
payload: `{"model":"gpt-5.4","instructions":null,"input":"hello"}`,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var gotPath string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(tc.payload),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Alt: "responses/compact",
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if gotPath != "/responses/compact" {
t.Fatalf("path = %q, want %q", gotPath, "/responses/compact")
}
if !gjson.GetBytes(gotBody, "instructions").Exists() {
t.Fatalf("expected instructions in compact request body, got %s", string(gotBody))
}
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
}
if gjson.GetBytes(gotBody, "instructions").String() != "" {
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
}
if string(resp.Payload) != `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}` {
t.Fatalf("payload = %s", string(resp.Payload))
}
})
}
}

View File

@@ -0,0 +1,123 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestCodexExecutorExecuteNormalizesNullInstructions(t *testing.T) {
var gotPath string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if gotPath != "/responses" {
t.Fatalf("path = %q, want %q", gotPath, "/responses")
}
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
}
if gjson.GetBytes(gotBody, "instructions").String() != "" {
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
}
}
func TestCodexExecutorExecuteStreamNormalizesNullInstructions(t *testing.T) {
var gotPath string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for range result.Chunks {
}
if gotPath != "/responses" {
t.Fatalf("path = %q, want %q", gotPath, "/responses")
}
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
}
if gjson.GetBytes(gotBody, "instructions").String() != "" {
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
}
}
func TestCodexExecutorCountTokensTreatsNullInstructionsAsEmpty(t *testing.T) {
executor := NewCodexExecutor(&config.Config{})
nullResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
})
if err != nil {
t.Fatalf("CountTokens(null) error: %v", err)
}
emptyResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4","instructions":"","input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
})
if err != nil {
t.Fatalf("CountTokens(empty) error: %v", err)
}
if string(nullResp.Payload) != string(emptyResp.Payload) {
t.Fatalf("token count payload mismatch:\nnull=%s\nempty=%s", string(nullResp.Payload), string(emptyResp.Payload))
}
}

View File

@@ -15,10 +15,12 @@ import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -44,10 +46,18 @@ const (
type CodexWebsocketsExecutor struct {
*CodexExecutor
sessMu sync.Mutex
store *codexWebsocketSessionStore
}
type codexWebsocketSessionStore struct {
mu sync.Mutex
sessions map[string]*codexWebsocketSession
}
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
sessions: make(map[string]*codexWebsocketSession),
}
type codexWebsocketSession struct {
sessionID string
@@ -71,7 +81,7 @@ type codexWebsocketSession struct {
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
return &CodexWebsocketsExecutor{
CodexExecutor: NewCodexExecutor(cfg),
sessions: make(map[string]*codexWebsocketSession),
store: globalCodexWebsocketSessionStore,
}
}
@@ -155,8 +165,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -173,11 +183,12 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
@@ -189,7 +200,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, err
}
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
@@ -208,8 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
wsReqBody := buildCodexWebsocketRequestBody(body)
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -223,12 +233,12 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if respHS != nil {
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.Execute(ctx, auth, req, opts)
@@ -236,7 +246,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if respHS != nil && respHS.StatusCode > 0 {
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
recordAPIResponseError(ctx, e.cfg, errDial)
helps.RecordAPIResponseError(ctx, e.cfg, errDial)
return resp, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
@@ -271,7 +281,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil {
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -287,15 +297,15 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
wsReqBody = wsReqBodyRetry
} else {
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
recordAPIResponseError(ctx, e.cfg, errSendRetry)
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry)
return resp, errSendRetry
}
} else {
recordAPIResponseError(ctx, e.cfg, errDialRetry)
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry)
return resp, errDialRetry
}
} else {
recordAPIResponseError(ctx, e.cfg, errSend)
helps.RecordAPIResponseError(ctx, e.cfg, errSend)
return resp, errSend
}
}
@@ -306,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
if msgType != websocket.TextMessage {
@@ -315,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
continue
@@ -325,21 +335,21 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
if len(payload) == 0 {
continue
}
appendAPIResponseChunk(ctx, e.cfg, payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
recordAPIResponseError(ctx, e.cfg, wsErr)
helps.RecordAPIResponseError(ctx, e.cfg, wsErr)
return resp, wsErr
}
payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" {
if detail, ok := parseCodexUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(payload); ok {
reporter.Publish(ctx, detail)
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param)
@@ -364,8 +374,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -376,8 +386,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
@@ -385,7 +395,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, err
}
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
@@ -403,8 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
wsReqBody := buildCodexWebsocketRequestBody(body)
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -420,12 +429,12 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
var upstreamHeaders http.Header
if respHS != nil {
upstreamHeaders = respHS.Header.Clone()
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
@@ -433,7 +442,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if respHS != nil && respHS.StatusCode > 0 {
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
recordAPIResponseError(ctx, e.cfg, errDial)
helps.RecordAPIResponseError(ctx, e.cfg, errDial)
if sess != nil {
sess.reqMu.Unlock()
}
@@ -452,20 +461,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
recordAPIResponseError(ctx, e.cfg, errSend)
helps.RecordAPIResponseError(ctx, e.cfg, errSend)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a new websocket connection for the same execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry != nil || connRetry == nil {
recordAPIResponseError(ctx, e.cfg, errDialRetry)
helps.RecordAPIResponseError(ctx, e.cfg, errDialRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
return nil, errDialRetry
}
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
@@ -477,7 +486,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
AuthValue: authValue,
})
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
recordAPIResponseError(ctx, e.cfg, errSendRetry)
helps.RecordAPIResponseError(ctx, e.cfg, errSendRetry)
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
@@ -543,8 +552,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
terminateReason = "read_error"
terminateErr = errRead
recordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx)
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
return
}
@@ -553,8 +562,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
err = fmt.Errorf("codex websockets executor: unexpected binary message")
terminateReason = "unexpected_binary"
terminateErr = err
recordAPIResponseError(ctx, e.cfg, err)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, err)
reporter.PublishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
@@ -568,13 +577,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
if len(payload) == 0 {
continue
}
appendAPIResponseChunk(ctx, e.cfg, payload)
helps.AppendAPIResponseChunk(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
terminateReason = "upstream_error"
terminateErr = wsErr
recordAPIResponseError(ctx, e.cfg, wsErr)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, wsErr)
reporter.PublishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
@@ -585,8 +594,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" || eventType == "response.done" {
if detail, ok := parseCodexUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseCodexUsage(payload); ok {
reporter.Publish(ctx, detail)
}
}
@@ -762,43 +771,39 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
return parsed.String(), nil
}
func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) ([]byte, http.Header, codexContinuity) {
func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) {
headers := http.Header{}
if len(rawJSON) == 0 {
return rawJSON, headers, codexContinuity{}
return rawJSON, headers
}
var cache codexCache
continuity := codexContinuity{}
var cache helps.CodexCache
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
if cached, ok := getCodexCache(key); ok {
if cached, ok := helps.GetCodexCache(key); ok {
cache = cached
} else {
cache = codexCache{
cache = helps.CodexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
setCodexCache(key, cache)
helps.SetCodexCache(key, cache)
}
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
}
} else if from == "openai-response" {
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
cache.ID = promptCacheKey.String()
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
}
} else if from == "openai" {
continuity = resolveCodexContinuity(ctx, auth, req, opts)
cache.ID = continuity.Key
}
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
applyCodexContinuityHeaders(headers, continuity)
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
headers.Set("Conversation_id", cache.ID)
}
return rawJSON, headers, continuity
return rawJSON, headers
}
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
@@ -810,11 +815,11 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
var ginHeaders http.Header
if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header.Clone()
}
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
@@ -830,8 +835,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
betaHeader = codexResponsesWebsocketBetaHeaderValue
}
headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "session_id", uuid.NewString())
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
}
headers.Del("User-Agent")
isAPIKey := false
if auth != nil && auth.Attributes != nil {
@@ -1059,16 +1066,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
if sessionID == "" {
return nil
}
e.sessMu.Lock()
defer e.sessMu.Unlock()
if e.sessions == nil {
e.sessions = make(map[string]*codexWebsocketSession)
if e == nil {
return nil
}
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
defer store.mu.Unlock()
if store.sessions == nil {
store.sessions = make(map[string]*codexWebsocketSession)
}
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
return sess
}
sess := &codexWebsocketSession{sessionID: sessionID}
e.sessions[sessionID] = sess
store.sessions[sessionID] = sess
return sess
}
@@ -1214,14 +1228,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
return
}
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
e.closeAllExecutionSessions("executor_replaced")
// Executor replacement can happen during hot reload (config/credential changes).
// Do not force-close upstream websocket sessions here, otherwise in-flight
// downstream websocket requests get interrupted.
return
}
e.sessMu.Lock()
sess := e.sessions[sessionID]
delete(e.sessions, sessionID)
e.sessMu.Unlock()
store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
sess := store.sessions[sessionID]
delete(store.sessions, sessionID)
store.mu.Unlock()
e.closeExecutionSession(sess, "session_closed")
}
@@ -1231,15 +1251,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
return
}
e.sessMu.Lock()
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
for sessionID, sess := range e.sessions {
delete(e.sessions, sessionID)
store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
delete(store.sessions, sessionID)
if sess != nil {
sessions = append(sessions, sess)
}
}
e.sessMu.Unlock()
store.mu.Unlock()
for i := range sessions {
e.closeExecutionSession(sessions[i], reason)
@@ -1247,6 +1271,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
}
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
closeCodexWebsocketSession(sess, reason)
}
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
if sess == nil {
return
}
@@ -1287,6 +1315,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
}
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
// associated with the supplied auth ID.
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
reason = strings.TrimSpace(reason)
if reason == "" {
reason = "auth_removed"
}
store := globalCodexWebsocketSessionStore
if store == nil {
return
}
type sessionItem struct {
sessionID string
sess *codexWebsocketSession
}
store.mu.Lock()
items := make([]sessionItem, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
}
store.mu.Unlock()
matches := make([]sessionItem, 0)
for i := range items {
sess := items[i].sess
if sess == nil {
continue
}
sess.connMu.Lock()
sessAuthID := strings.TrimSpace(sess.authID)
sess.connMu.Unlock()
if sessAuthID == authID {
matches = append(matches, items[i])
}
}
if len(matches) == 0 {
return
}
toClose := make([]*codexWebsocketSession, 0, len(matches))
store.mu.Lock()
for i := range matches {
current, ok := store.sessions[matches[i].sessionID]
if !ok || current == nil || current != matches[i].sess {
continue
}
delete(store.sessions, matches[i].sessionID)
toClose = append(toClose, current)
}
store.mu.Unlock()
for i := range toClose {
closeCodexWebsocketSession(toClose[i], reason)
}
}
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
// 1. The downstream transport is websocket, and
// 2. The selected auth enables websockets.

View File

@@ -0,0 +1,48 @@
package executor
import (
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
sessionID := "test-session-store-survives-replace"
globalCodexWebsocketSessionStore.mu.Lock()
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
globalCodexWebsocketSessionStore.mu.Unlock()
exec1 := NewCodexWebsocketsExecutor(nil)
sess1 := exec1.getOrCreateSession(sessionID)
if sess1 == nil {
t.Fatalf("expected session to be created")
}
exec2 := NewCodexWebsocketsExecutor(nil)
sess2 := exec2.getOrCreateSession(sessionID)
if sess2 == nil {
t.Fatalf("expected session to be available across executors")
}
if sess1 != sess2 {
t.Fatalf("expected the same session instance across executors")
}
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
globalCodexWebsocketSessionStore.mu.Lock()
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
globalCodexWebsocketSessionStore.mu.Unlock()
if !stillPresent {
t.Fatalf("expected session to remain after executor replacement close marker")
}
exec2.CloseExecutionSession(sessionID)
globalCodexWebsocketSessionStore.mu.Lock()
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
globalCodexWebsocketSessionStore.mu.Unlock()
if presentAfterClose {
t.Fatalf("expected session to be removed after explicit close")
}
}

View File

@@ -9,9 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -34,57 +32,14 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
}
}
func TestApplyCodexPromptCacheHeaders_PreservesPromptCacheRetention(t *testing.T) {
req := cliproxyexecutor.Request{
Model: "gpt-5-codex",
Payload: []byte(`{"prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
}
body := []byte(`{"model":"gpt-5-codex","stream":true,"prompt_cache_retention":"persistent"}`)
updatedBody, headers, _ := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("openai-response"), req, cliproxyexecutor.Options{}, body)
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != "cache-key-1" {
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
}
if got := gjson.GetBytes(updatedBody, "prompt_cache_retention").String(); got != "persistent" {
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
}
if got := headers.Get("session_id"); got != "cache-key-1" {
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
}
if got := headers.Get("Conversation_id"); got != "" {
t.Fatalf("Conversation_id = %q, want empty", got)
}
}
func TestApplyCodexPromptCacheHeaders_ClaudePreservesContinuity(t *testing.T) {
req := cliproxyexecutor.Request{
Model: "claude-3-7-sonnet",
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
}
body := []byte(`{"model":"gpt-5.4","stream":true}`)
updatedBody, headers, continuity := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("claude"), req, cliproxyexecutor.Options{}, body)
if continuity.Key == "" {
t.Fatal("continuity.Key = empty, want non-empty")
}
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != continuity.Key {
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
}
if got := headers.Get("session_id"); got != continuity.Key {
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
@@ -142,8 +97,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
@@ -174,8 +129,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
if gotVal := got.Get("User-Agent"); gotVal != "" {
t.Fatalf("User-Agent = %s, want empty", gotVal)
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
@@ -200,8 +155,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
@@ -222,8 +177,8 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)

View File

@@ -0,0 +1,129 @@
package executor
import (
"context"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
}
func parseOpenAIUsage(data []byte) usage.Detail {
return helps.ParseOpenAIUsage(data)
}
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return helps.ParseOpenAIStreamUsage(line)
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
return helps.ParseOpenAIUsage(data)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
return helps.ParseOpenAIStreamUsage(line)
}
func getTokenizer(model string) (tokenizer.Codec, error) {
return helps.TokenizerForModel(model)
}
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return helps.CountOpenAIChatTokens(enc, payload)
}
func countClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return helps.CountClaudeChatTokens(enc, payload)
}
func buildOpenAIUsageJSON(count int64) []byte {
return helps.BuildOpenAIUsageJSON(count)
}
type upstreamRequestLog = helps.UpstreamRequestLog
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
helps.RecordAPIRequest(ctx, cfg, info)
}
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
helps.RecordAPIResponseMetadata(ctx, cfg, status, headers)
}
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
helps.RecordAPIResponseError(ctx, cfg, err)
}
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
helps.AppendAPIResponseChunk(ctx, cfg, chunk)
}
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
return helps.PayloadRequestedModel(opts, fallback)
}
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
return helps.ApplyPayloadConfigWithRoot(cfg, model, protocol, root, payload, original, requestedModel)
}
func summarizeErrorBody(contentType string, body []byte) string {
return helps.SummarizeErrorBody(contentType, body)
}
func apiKeyFromContext(ctx context.Context) string {
return helps.APIKeyFromContext(ctx)
}
func tokenizerForModel(model string) (tokenizer.Codec, error) {
return helps.TokenizerForModel(model)
}
func collectOpenAIContent(content gjson.Result, segments *[]string) {
helps.CollectOpenAIContent(content, segments)
}
type usageReporter struct {
reporter *helps.UsageReporter
}
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
return &usageReporter{reporter: helps.NewUsageReporter(ctx, provider, model, auth)}
}
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
if r == nil || r.reporter == nil {
return
}
r.reporter.Publish(ctx, detail)
}
func (r *usageReporter) publishFailure(ctx context.Context) {
if r == nil || r.reporter == nil {
return
}
r.reporter.PublishFailure(ctx)
}
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
if r == nil || r.reporter == nil {
return
}
r.reporter.TrackFailure(ctx, errPtr)
}
func (r *usageReporter) ensurePublished(ctx context.Context) {
if r == nil || r.reporter == nil {
return
}
r.reporter.EnsurePublished(ctx)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -112,8 +113,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
@@ -132,8 +133,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
action := "generateContent"
if req.Metadata != nil {
@@ -190,7 +191,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
@@ -204,7 +205,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
err = errDo
return resp, err
}
@@ -213,15 +214,15 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
reporter.publish(ctx, parseGeminiCLIUsage(data))
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -230,7 +231,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 {
if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -245,7 +246,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody)
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
}
if lastStatus == 0 {
lastStatus = 429
@@ -266,8 +267,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
return nil, err
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
@@ -286,8 +287,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
projectID := resolveGeminiProjectID(auth)
@@ -335,7 +336,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "text/event-stream")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
@@ -349,25 +350,25 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 {
if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -394,9 +395,9 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiCLIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseGeminiCLIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
@@ -411,8 +412,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
return
@@ -420,13 +421,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
data, errRead := io.ReadAll(resp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errRead}
return
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
@@ -443,7 +444,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody)
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
}
if lastStatus == 0 {
lastStatus = 429
@@ -516,7 +517,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, baseModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
@@ -530,17 +531,19 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
resp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
data, errRead := io.ReadAll(resp.Body)
_ = resp.Body.Close()
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if errClose := resp.Body.Close(); errClose != nil {
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
@@ -611,7 +614,7 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
}
ctxToken := ctx
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
}
@@ -707,7 +710,7 @@ func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
}
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
}
func cloneMap(in map[string]any) map[string]any {

View File

@@ -13,6 +13,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -85,7 +86,7 @@ func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -110,8 +111,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
@@ -130,8 +131,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := "generateContent"
@@ -165,7 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -177,10 +178,10 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -188,21 +189,21 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
log.Errorf("gemini executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -218,8 +219,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -237,8 +238,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth)
@@ -268,7 +269,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -280,17 +281,17 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose)
}
@@ -310,14 +311,14 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
filtered := FilterSSEUsageMetadata(line)
payload := jsonPayload(filtered)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
filtered := helps.FilterSSEUsageMetadata(line)
payload := helps.JSONPayload(filtered)
if len(payload) == 0 {
continue
}
if detail, ok := parseGeminiStreamUsage(payload); ok {
reporter.publish(ctx, detail)
if detail, ok := helps.ParseGeminiStreamUsage(payload); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
@@ -329,8 +330,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -381,7 +382,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -393,23 +394,27 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
resp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
defer func() { _ = resp.Body.Close() }()
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
data, err := io.ReadAll(resp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, helps.SummarizeErrorBody(resp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
}

View File

@@ -16,6 +16,7 @@ import (
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -227,7 +228,7 @@ func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -301,8 +302,8 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
var body []byte
@@ -332,8 +333,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
}
@@ -369,7 +370,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -381,10 +382,10 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
@@ -392,21 +393,21 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
// For Imagen models, convert response to Gemini format before translation
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
@@ -427,8 +428,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -447,8 +448,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, false)
@@ -484,7 +485,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -496,10 +497,10 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
@@ -507,21 +508,21 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
@@ -532,8 +533,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -552,8 +553,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
@@ -588,7 +589,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -600,17 +601,17 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
@@ -630,9 +631,9 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
@@ -644,8 +645,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -656,8 +657,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -676,8 +677,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
@@ -712,7 +713,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -724,17 +725,17 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
@@ -754,9 +755,9 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
@@ -768,8 +769,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -819,7 +820,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -831,10 +832,10 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
@@ -842,19 +843,19 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
@@ -903,7 +904,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -915,10 +916,10 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
@@ -926,19 +927,19 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
@@ -1012,7 +1013,7 @@ func vertexBaseURL(location string) string {
}
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
// Use cloud-platform scope for Vertex AI.

View File

@@ -76,6 +76,9 @@ func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses")
}
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
t.Fatal("expected responses-only registry model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
@@ -83,15 +86,25 @@ func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *test
reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client"
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
}})
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5.4-mini",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
})
defer reg.UnregisterClient(clientID)
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {

View File

@@ -1,11 +1,11 @@
package executor
package helps
import (
"sync"
"time"
)
type codexCache struct {
type CodexCache struct {
ID string
Expire time.Time
}
@@ -13,7 +13,7 @@ type codexCache struct {
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
// Protected by codexCacheMu. Entries expire after 1 hour.
var (
codexCacheMap = make(map[string]codexCache)
codexCacheMap = make(map[string]CodexCache)
codexCacheMu sync.RWMutex
)
@@ -50,20 +50,20 @@ func purgeExpiredCodexCache() {
}
}
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func getCodexCache(key string) (codexCache, bool) {
// GetCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func GetCodexCache(key string) (CodexCache, bool) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.RLock()
cache, ok := codexCacheMap[key]
codexCacheMu.RUnlock()
if !ok || cache.Expire.Before(time.Now()) {
return codexCache{}, false
return CodexCache{}, false
}
return cache, true
}
// setCodexCache stores a cache entry.
func setCodexCache(key string, cache codexCache) {
// SetCodexCache stores a cache entry.
func SetCodexCache(key string, cache CodexCache) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.Lock()
codexCacheMap[key] = cache

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"crypto/sha256"
@@ -32,7 +32,7 @@ var (
claudeDeviceProfileCacheMu sync.RWMutex
claudeDeviceProfileCacheCleanupOnce sync.Once
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
ClaudeDeviceProfileBeforeCandidateStore func(ClaudeDeviceProfile)
)
type claudeCLIVersion struct {
@@ -63,29 +63,43 @@ func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
}
}
type claudeDeviceProfile struct {
type ClaudeDeviceProfile struct {
UserAgent string
PackageVersion string
RuntimeVersion string
OS string
Arch string
Version claudeCLIVersion
HasVersion bool
version claudeCLIVersion
hasVersion bool
}
type claudeDeviceProfileCacheEntry struct {
profile claudeDeviceProfile
profile ClaudeDeviceProfile
expire time.Time
}
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
func ClaudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
return false
}
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
}
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
func ResetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
}
func MapStainlessOS() string {
return mapStainlessOS()
}
func MapStainlessArch() string {
return mapStainlessArch()
}
func defaultClaudeDeviceProfile(cfg *config.Config) ClaudeDeviceProfile {
hdrDefault := func(cfgVal, fallback string) string {
if strings.TrimSpace(cfgVal) != "" {
return strings.TrimSpace(cfgVal)
@@ -98,7 +112,7 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
hd = cfg.ClaudeHeaderDefaults
}
profile := claudeDeviceProfile{
profile := ClaudeDeviceProfile{
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
@@ -106,8 +120,8 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
}
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
profile.Version = version
profile.HasVersion = true
profile.version = version
profile.hasVersion = true
}
return profile
}
@@ -162,17 +176,17 @@ func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
}
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.HasVersion {
func shouldUpgradeClaudeDeviceProfile(candidate, current ClaudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.hasVersion {
return false
}
if current.UserAgent == "" || !current.HasVersion {
if current.UserAgent == "" || !current.hasVersion {
return true
}
return candidate.Version.Compare(current.Version) > 0
return candidate.version.Compare(current.version) > 0
}
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
func pinClaudeDeviceProfilePlatform(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
profile.OS = baseline.OS
profile.Arch = baseline.Arch
return profile
@@ -180,38 +194,38 @@ func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claud
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
// baseline platform and enforces the baseline software fingerprint as a floor.
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
func normalizeClaudeDeviceProfile(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
if profile.UserAgent == "" || !profile.hasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
profile.UserAgent = baseline.UserAgent
profile.PackageVersion = baseline.PackageVersion
profile.RuntimeVersion = baseline.RuntimeVersion
profile.Version = baseline.Version
profile.HasVersion = baseline.HasVersion
profile.version = baseline.version
profile.hasVersion = baseline.hasVersion
}
return profile
}
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (ClaudeDeviceProfile, bool) {
if headers == nil {
return claudeDeviceProfile{}, false
return ClaudeDeviceProfile{}, false
}
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
version, ok := parseClaudeCLIVersion(userAgent)
if !ok {
return claudeDeviceProfile{}, false
return ClaudeDeviceProfile{}, false
}
baseline := defaultClaudeDeviceProfile(cfg)
profile := claudeDeviceProfile{
profile := ClaudeDeviceProfile{
UserAgent: userAgent,
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
Version: version,
HasVersion: true,
version: version,
hasVersion: true,
}
return profile, true
}
@@ -263,7 +277,7 @@ func purgeExpiredClaudeDeviceProfiles() {
claudeDeviceProfileCacheMu.Unlock()
}
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
func ResolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) ClaudeDeviceProfile {
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
@@ -283,8 +297,8 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
claudeDeviceProfileCacheMu.RUnlock()
if hasCandidate {
if claudeDeviceProfileBeforeCandidateStore != nil {
claudeDeviceProfileBeforeCandidateStore(candidate)
if ClaudeDeviceProfileBeforeCandidateStore != nil {
ClaudeDeviceProfileBeforeCandidateStore(candidate)
}
claudeDeviceProfileCacheMu.Lock()
@@ -324,7 +338,7 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
return baseline
}
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfile) {
if r == nil {
return
}
@@ -344,7 +358,7 @@ func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfil
r.Header.Set("X-Stainless-Arch", profile.Arch)
}
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
if r == nil {
return
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"regexp"
@@ -18,9 +18,9 @@ type SensitiveWordMatcher struct {
regex *regexp.Regexp
}
// buildSensitiveWordMatcher compiles a regex from the word list.
// BuildSensitiveWordMatcher compiles a regex from the word list.
// Words are sorted by length (longest first) for proper matching.
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
func BuildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
if len(words) == 0 {
return nil
}
@@ -81,9 +81,9 @@ func (m *SensitiveWordMatcher) obfuscateText(text string) string {
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
}
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
// ObfuscateSensitiveWords processes the payload and obfuscates sensitive words
// in system blocks and message content.
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
func ObfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
if matcher == nil || matcher.regex == nil {
return payload
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"crypto/rand"
@@ -28,9 +28,17 @@ func isValidUserID(userID string) bool {
return userIDPattern.MatchString(userID)
}
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
func GenerateFakeUserID() string {
return generateFakeUserID()
}
func IsValidUserID(userID string) bool {
return isValidUserID(userID)
}
// ShouldCloak determines if request should be cloaked based on config and client User-Agent.
// Returns true if cloaking should be applied.
func shouldCloak(cloakMode string, userAgent string) bool {
func ShouldCloak(cloakMode string, userAgent string) bool {
switch strings.ToLower(cloakMode) {
case "always":
return true

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"bytes"
@@ -24,8 +24,8 @@ const (
apiResponseKey = "API_RESPONSE"
)
// upstreamRequestLog captures the outbound upstream request details for logging.
type upstreamRequestLog struct {
// UpstreamRequestLog captures the outbound upstream request details for logging.
type UpstreamRequestLog struct {
URL string
Method string
Headers http.Header
@@ -49,8 +49,8 @@ type upstreamAttempt struct {
errorWritten bool
}
// recordAPIRequest stores the upstream request metadata in Gin context for request logging.
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
// RecordAPIRequest stores the upstream request metadata in Gin context for request logging.
func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
if cfg == nil || !cfg.RequestLog {
return
}
@@ -96,8 +96,8 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
updateAggregatedRequest(ginCtx, attempts)
}
// recordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
if cfg == nil || !cfg.RequestLog {
return
}
@@ -122,8 +122,8 @@ func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status i
updateAggregatedResponse(ginCtx, attempts)
}
// recordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
// RecordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
func RecordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
if cfg == nil || !cfg.RequestLog || err == nil {
return
}
@@ -147,8 +147,8 @@ func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error)
updateAggregatedResponse(ginCtx, attempts)
}
// appendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
// AppendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
if cfg == nil || !cfg.RequestLog {
return
}
@@ -285,7 +285,7 @@ func writeHeaders(builder *strings.Builder, headers http.Header) {
}
}
func formatAuthInfo(info upstreamRequestLog) string {
func formatAuthInfo(info UpstreamRequestLog) string {
var parts []string
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
@@ -321,7 +321,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
return strings.Join(parts, ", ")
}
func summarizeErrorBody(contentType string, body []byte) string {
func SummarizeErrorBody(contentType string, body []byte) string {
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
if !isHTML {
trimmed := bytes.TrimSpace(bytes.ToLower(body))
@@ -379,7 +379,7 @@ func extractJSONErrorMessage(body []byte) string {
// logWithRequestID returns a logrus Entry with request_id field populated from context.
// If no request ID is found in context, it returns the standard logger.
func logWithRequestID(ctx context.Context) *log.Entry {
func LogWithRequestID(ctx context.Context) *log.Entry {
if ctx == nil {
return log.NewEntry(log.StandardLogger())
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"encoding/json"
@@ -11,12 +11,12 @@ import (
"github.com/tidwall/sjson"
)
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided. requestedModel carries the client-visible
// model name before alias resolution so payload rules can target aliases precisely.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
if cfg == nil || len(payload) == 0 {
return payload
}
@@ -244,7 +244,7 @@ func payloadRawValue(value any) ([]byte, bool) {
}
}
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
fallback = strings.TrimSpace(fallback)
if len(opts.Metadata) == 0 {
return fallback

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"context"
@@ -19,7 +19,7 @@ var (
httpClientCacheMutex sync.RWMutex
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured
@@ -34,7 +34,7 @@ var (
//
// Returns:
// - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
// Priority 1: Use auth.ProxyURL if configured
var proxyURL string
if auth != nil {
@@ -46,23 +46,18 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
// Build cache key from proxy URL (empty string for no proxy)
cacheKey := proxyURL
// Check cache first
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[cacheKey]; ok {
httpClientCacheMutex.RUnlock()
// Return a wrapper with the requested timeout but shared transport
if timeout > 0 {
return &http.Client{
Transport: cachedClient.Transport,
Timeout: timeout,
// If we have a proxy URL configured, try cache first to reuse TCP/TLS connections.
if proxyURL != "" {
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[proxyURL]; ok {
httpClientCacheMutex.RUnlock()
if timeout > 0 {
return &http.Client{Transport: cachedClient.Transport, Timeout: timeout}
}
return cachedClient
}
return cachedClient
httpClientCacheMutex.RUnlock()
}
httpClientCacheMutex.RUnlock()
// Create new client
httpClient := &http.Client{}
@@ -77,7 +72,7 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = transport
// Cache the client
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCache[proxyURL] = httpClient
httpClientCacheMutex.Unlock()
return httpClient
}
@@ -90,13 +85,6 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt
}
// Cache the client for no-proxy case
if proxyURL == "" {
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
}
return httpClient
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"context"
@@ -13,7 +13,7 @@ import (
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
client := NewProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"

View File

@@ -1,9 +1,7 @@
package executor
package helps
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
@@ -11,100 +9,80 @@ import (
"github.com/tiktoken-go/tokenizer"
)
// tokenizerCache stores tokenizer instances to avoid repeated creation
// tokenizerCache stores tokenizer instances to avoid repeated creation.
var tokenizerCache sync.Map
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
type TokenizerWrapper struct {
Codec tokenizer.Codec
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
type adjustedTokenizer struct {
tokenizer.Codec
adjustmentFactor float64
}
// Count returns the token count with adjustment factor applied
func (tw *TokenizerWrapper) Count(text string) (int, error) {
func (tw *adjustedTokenizer) Count(text string) (int, error) {
count, err := tw.Codec.Count(text)
if err != nil {
return 0, err
}
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
return int(float64(count) * tw.AdjustmentFactor), nil
if tw.adjustmentFactor > 0 && tw.adjustmentFactor != 1.0 {
return int(float64(count) * tw.adjustmentFactor), nil
}
return count, nil
}
// getTokenizer returns a cached tokenizer for the given model.
// This improves performance by avoiding repeated tokenizer creation.
func getTokenizer(model string) (*TokenizerWrapper, error) {
// Check cache first
if cached, ok := tokenizerCache.Load(model); ok {
return cached.(*TokenizerWrapper), nil
// TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
// For Claude-like models, it applies an adjustment factor since tiktoken may underestimate token counts.
func TokenizerForModel(model string) (tokenizer.Codec, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
if cached, ok := tokenizerCache.Load(sanitized); ok {
return cached.(tokenizer.Codec), nil
}
// Cache miss, create new tokenizer
wrapper, err := tokenizerForModel(model)
enc, err := tokenizerForModel(sanitized)
if err != nil {
return nil, err
}
// Store in cache (use LoadOrStore to handle race conditions)
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
return actual.(*TokenizerWrapper), nil
actual, _ := tokenizerCache.LoadOrStore(sanitized, enc)
return actual.(tokenizer.Codec), nil
}
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
func tokenizerForModel(sanitized string) (tokenizer.Codec, error) {
if sanitized == "" {
return tokenizer.Get(tokenizer.Cl100kBase)
}
// Claude models use cl100k_base with 1.1 adjustment factor
// because tiktoken may underestimate Claude's actual token count
// Claude models use cl100k_base with an adjustment factor because tiktoken may underestimate.
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
return &adjustedTokenizer{Codec: enc, adjustmentFactor: 1.1}, nil
}
var enc tokenizer.Codec
var err error
switch {
case sanitized == "":
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5.2"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
return tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
enc, err = tokenizer.ForModel(tokenizer.GPT41)
return tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
return tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
enc, err = tokenizer.ForModel(tokenizer.GPT4)
return tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
return tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
enc, err = tokenizer.ForModel(tokenizer.O1)
return tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
enc, err = tokenizer.ForModel(tokenizer.O3)
return tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
return tokenizer.ForModel(tokenizer.O4Mini)
default:
enc, err = tokenizer.Get(tokenizer.O200kBase)
return tokenizer.Get(tokenizer.O200kBase)
}
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
// CountOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -128,22 +106,15 @@ func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
return int64(count), nil
}
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
// This handles Claude's message format with system, messages, and tools.
// Image tokens are estimated based on image dimensions when available.
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
// CountClaudeChatTokens approximates prompt tokens for Claude API chat payloads.
func CountClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -153,185 +124,25 @@ func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32)
imageTokens := 0
// Collect system prompt (can be string or array of content blocks)
collectClaudeSystem(root.Get("system"), &segments)
// Collect messages
collectClaudeMessages(root.Get("messages"), &segments)
// Collect tools
collectClaudeContent(root.Get("system"), &segments, &imageTokens)
collectClaudeMessages(root.Get("messages"), &segments, &imageTokens)
collectClaudeTools(root.Get("tools"), &segments)
joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" {
return 0, nil
return int64(imageTokens), nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
return int64(count + imageTokens), nil
}
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
// extractImageTokens extracts image token estimates from placeholder text.
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
func extractImageTokens(text string) int {
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
total := 0
for _, match := range matches {
if len(match) > 1 {
if tokens, err := strconv.Atoi(match[1]); err == nil {
total += tokens
}
}
}
return total
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image)
return 1000
}
tokens := int(width * height / 750)
// Apply bounds
if tokens < 85 {
tokens = 85
}
if tokens > 1590 {
tokens = 1590
}
return tokens
}
// collectClaudeSystem extracts text from Claude's system field.
// System can be a string or an array of content blocks.
func collectClaudeSystem(system gjson.Result, segments *[]string) {
if !system.Exists() {
return
}
if system.Type == gjson.String {
addIfNotEmpty(segments, system.String())
return
}
if system.IsArray() {
system.ForEach(func(_, block gjson.Result) bool {
blockType := block.Get("type").String()
if blockType == "text" || blockType == "" {
addIfNotEmpty(segments, block.Get("text").String())
}
// Also handle plain string blocks
if block.Type == gjson.String {
addIfNotEmpty(segments, block.String())
}
return true
})
}
}
// collectClaudeMessages extracts text from Claude's messages array.
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments)
return true
})
}
// collectClaudeContent extracts text from Claude's content field.
// Content can be a string or an array of content blocks.
// For images, estimates token count based on dimensions when available.
func collectClaudeContent(content gjson.Result, segments *[]string) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
// Estimate image tokens based on dimensions if available
source := part.Get("source")
if source.Exists() {
width := source.Get("width").Float()
height := source.Get("height").Float()
if width > 0 && height > 0 {
tokens := estimateImageTokens(width, height)
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
} else {
// No dimensions available, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
} else {
// No source info, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
// For unknown types, try to extract any text content
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
}
}
// collectClaudeTools extracts text from Claude's tools array.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
func buildOpenAIUsageJSON(count int64) []byte {
// BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
func BuildOpenAIUsageJSON(count int64) []byte {
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
}
@@ -390,6 +201,10 @@ func collectOpenAIContent(content gjson.Result, segments *[]string) {
}
}
func CollectOpenAIContent(content gjson.Result, segments *[]string) {
collectOpenAIContent(content, segments)
}
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
if !calls.Exists() || !calls.IsArray() {
return
@@ -487,6 +302,98 @@ func appendToolPayload(tool gjson.Result, segments *[]string) {
}
}
func collectClaudeMessages(messages gjson.Result, segments *[]string, imageTokens *int) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments, imageTokens)
return true
})
}
func collectClaudeContent(content gjson.Result, segments *[]string, imageTokens *int) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
source := part.Get("source")
width := source.Get("width").Float()
height := source.Get("height").Float()
if imageTokens != nil {
*imageTokens += estimateImageTokens(width, height)
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments, imageTokens)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
return
}
if content.Type == gjson.JSON {
addIfNotEmpty(segments, content.Raw)
}
}
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image).
return 1000
}
tokens := int(width * height / 750)
if tokens < 85 {
return 85
}
if tokens > 1590 {
return 1590
}
return tokens
}
func addIfNotEmpty(segments *[]string, value string) {
if segments == nil {
return

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"bytes"
@@ -15,7 +15,7 @@ import (
"github.com/tidwall/sjson"
)
type usageReporter struct {
type UsageReporter struct {
provider string
model string
authID string
@@ -26,9 +26,9 @@ type usageReporter struct {
once sync.Once
}
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
apiKey := apiKeyFromContext(ctx)
reporter := &usageReporter{
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := APIKeyFromContext(ctx)
reporter := &UsageReporter{
provider: provider,
model: model,
requestedAt: time.Now(),
@@ -42,24 +42,24 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
return reporter
}
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
func (r *UsageReporter) Publish(ctx context.Context, detail usage.Detail) {
r.publishWithOutcome(ctx, detail, false)
}
func (r *usageReporter) publishFailure(ctx context.Context) {
func (r *UsageReporter) PublishFailure(ctx context.Context) {
r.publishWithOutcome(ctx, usage.Detail{}, true)
}
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
func (r *UsageReporter) TrackFailure(ctx context.Context, errPtr *error) {
if r == nil || errPtr == nil {
return
}
if *errPtr != nil {
r.publishFailure(ctx)
r.PublishFailure(ctx)
}
}
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
if r == nil {
return
}
@@ -81,7 +81,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
// It is safe to call multiple times; only the first call wins due to once.Do.
// This is used to ensure request counting even when upstream responses do not
// include any usage fields (tokens), especially for streaming paths.
func (r *usageReporter) ensurePublished(ctx context.Context) {
func (r *UsageReporter) EnsurePublished(ctx context.Context) {
if r == nil {
return
}
@@ -90,7 +90,7 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
})
}
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
if r == nil {
return usage.Record{Detail: detail, Failed: failed}
}
@@ -108,7 +108,7 @@ func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
}
}
func (r *usageReporter) latency() time.Duration {
func (r *UsageReporter) latency() time.Duration {
if r == nil || r.requestedAt.IsZero() {
return 0
}
@@ -119,7 +119,7 @@ func (r *usageReporter) latency() time.Duration {
return latency
}
func apiKeyFromContext(ctx context.Context) string {
func APIKeyFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
@@ -184,7 +184,7 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
return ""
}
func parseCodexUsage(data []byte) (usage.Detail, bool) {
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
usageNode := gjson.ParseBytes(data).Get("response.usage")
if !usageNode.Exists() {
return usage.Detail{}, false
@@ -203,7 +203,7 @@ func parseCodexUsage(data []byte) (usage.Detail, bool) {
return detail, true
}
func parseOpenAIUsage(data []byte) usage.Detail {
func ParseOpenAIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
@@ -238,7 +238,7 @@ func parseOpenAIUsage(data []byte) usage.Detail {
return detail
}
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -247,59 +247,40 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
if !usageNode.Exists() {
return usage.Detail{}, false
}
inputNode := usageNode.Get("prompt_tokens")
if !inputNode.Exists() {
inputNode = usageNode.Get("input_tokens")
}
outputNode := usageNode.Get("completion_tokens")
if !outputNode.Exists() {
outputNode = usageNode.Get("output_tokens")
}
detail := usage.Detail{
InputTokens: usageNode.Get("prompt_tokens").Int(),
OutputTokens: usageNode.Get("completion_tokens").Int(),
InputTokens: inputNode.Int(),
OutputTokens: outputNode.Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
if !cached.Exists() {
cached = usageNode.Get("input_tokens_details.cached_tokens")
}
if cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
if !reasoning.Exists() {
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
}
if reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail, true
}
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: usageNode.Get("input_tokens").Int(),
OutputTokens: usageNode.Get("output_tokens").Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
}
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
}
return parseOpenAIResponsesUsageDetail(usageNode)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
usageNode := gjson.GetBytes(payload, "usage")
if !usageNode.Exists() {
return usage.Detail{}, false
}
return parseOpenAIResponsesUsageDetail(usageNode), true
}
func parseClaudeUsage(data []byte) usage.Detail {
func ParseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
@@ -317,7 +298,7 @@ func parseClaudeUsage(data []byte) usage.Detail {
return detail
}
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
func ParseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -352,7 +333,7 @@ func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
return detail
}
func parseGeminiCLIUsage(data []byte) usage.Detail {
func ParseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
if !node.Exists() {
@@ -364,7 +345,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiUsage(data []byte) usage.Detail {
func ParseGeminiUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("usageMetadata")
if !node.Exists() {
@@ -376,7 +357,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -391,7 +372,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
return parseGeminiFamilyUsageDetail(node), true
}
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -406,7 +387,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
return parseGeminiFamilyUsageDetail(node), true
}
func parseAntigravityUsage(data []byte) usage.Detail {
func ParseAntigravityUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
if !node.Exists() {
@@ -421,7 +402,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
return parseGeminiFamilyUsageDetail(node)
}
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
func ParseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
@@ -590,6 +571,10 @@ func isStopChunkWithoutUsage(jsonBytes []byte) bool {
return !hasUsageMetadata(jsonBytes)
}
func JSONPayload(line []byte) []byte {
return jsonPayload(line)
}
func jsonPayload(line []byte) []byte {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 {

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"testing"
@@ -9,7 +9,7 @@ import (
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
detail := parseOpenAIUsage(data)
detail := ParseOpenAIUsage(data)
if detail.InputTokens != 1 {
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
}
@@ -29,7 +29,7 @@ func TestParseOpenAIUsageChatCompletions(t *testing.T) {
func TestParseOpenAIUsageResponses(t *testing.T) {
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
detail := parseOpenAIUsage(data)
detail := ParseOpenAIUsage(data)
if detail.InputTokens != 10 {
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
}
@@ -48,7 +48,7 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
}
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
reporter := &usageReporter{
reporter := &UsageReporter{
provider: "openai",
model: "gpt-5.4",
requestedAt: time.Now().Add(-1500 * time.Millisecond),

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"crypto/sha256"
@@ -49,7 +49,7 @@ func userIDCacheKey(apiKey string) string {
return hex.EncodeToString(sum[:])
}
func cachedUserID(apiKey string) string {
func CachedUserID(apiKey string) string {
if apiKey == "" {
return generateFakeUserID()
}

View File

@@ -1,4 +1,4 @@
package executor
package helps
import (
"testing"
@@ -14,8 +14,8 @@ func resetUserIDCache() {
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-1")
first := CachedUserID("api-key-1")
second := CachedUserID("api-key-1")
if first == "" {
t.Fatal("expected generated user_id to be non-empty")
@@ -28,7 +28,7 @@ func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
resetUserIDCache()
expiredID := cachedUserID("api-key-expired")
expiredID := CachedUserID("api-key-expired")
cacheKey := userIDCacheKey("api-key-expired")
userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{
@@ -37,7 +37,7 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
}
userIDCacheMu.Unlock()
newID := cachedUserID("api-key-expired")
newID := CachedUserID("api-key-expired")
if newID == expiredID {
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
}
@@ -49,8 +49,8 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-2")
first := CachedUserID("api-key-1")
second := CachedUserID("api-key-2")
if first == second {
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
@@ -61,7 +61,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
resetUserIDCache()
key := "api-key-renew"
id := cachedUserID(key)
id := CachedUserID(key)
cacheKey := userIDCacheKey(key)
soon := time.Now()
@@ -72,7 +72,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
}
userIDCacheMu.Unlock()
if refreshed := cachedUserID(key); refreshed != id {
if refreshed := CachedUserID(key); refreshed != id {
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/uuid"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -66,7 +67,7 @@ func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -86,8 +87,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -106,8 +107,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
}
body = preserveReasoningContentInMessages(body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -122,7 +123,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -134,10 +135,10 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -145,25 +146,25 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
log.Errorf("iflow executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
// Ensure usage is recorded even if upstream omits usage metadata.
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
@@ -189,8 +190,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -214,8 +215,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -230,7 +231,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -242,21 +243,21 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, _ := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("iflow executor: close response body error: %v", errClose)
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err
}
@@ -275,9 +276,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
@@ -285,12 +286,12 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
// Guarantee a usage record exists even if the stream never emitted usage data.
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -303,17 +304,17 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
enc, err := tokenizerForModel(baseModel)
enc, err := helps.TokenizerForModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, body)
count, err := helps.CountOpenAIChatTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
usageJSON := helps.BuildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translated}, nil
}

View File

@@ -15,6 +15,7 @@ import (
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -60,7 +61,7 @@ func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -76,8 +77,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
token := kimiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
@@ -100,8 +101,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = normalizeKimiToolMessageLinks(body)
if err != nil {
return resp, err
@@ -119,7 +120,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -131,10 +132,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -142,21 +143,21 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Errorf("kimi executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
@@ -176,8 +177,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
baseModel := thinking.ParseSuffix(req.Model).ModelName
token := kimiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
@@ -204,8 +205,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, err = normalizeKimiToolMessageLinks(body)
if err != nil {
return nil, err
@@ -223,7 +224,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -235,17 +236,17 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("kimi executor: close response body error: %v", errClose)
}
@@ -265,9 +266,9 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
@@ -279,8 +280,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -65,15 +66,15 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
@@ -95,8 +96,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
if opts.Alt == "responses/compact" {
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
translated = updated
@@ -129,7 +130,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -141,10 +142,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -152,23 +153,23 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
// Ensure we at least record the request even if upstream doesn't return usage
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
@@ -179,8 +180,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
@@ -197,8 +198,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
@@ -232,7 +233,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -244,17 +245,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
@@ -274,9 +275,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if len(line) == 0 {
continue
@@ -294,12 +295,12 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
// Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx)
reporter.EnsurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
@@ -318,17 +319,17 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
return cliproxyexecutor.Response{}, err
}
enc, err := tokenizerForModel(modelForCounting)
enc, err := helps.TokenizerForModel(modelForCounting)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, translated)
count, err := helps.CountOpenAIChatTokens(enc, translated)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
usageJSON := helps.BuildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
}

View File

@@ -13,6 +13,7 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -23,7 +24,7 @@ import (
)
const (
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenUserAgent = "QwenCode/0.13.2 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration
)
@@ -154,7 +155,7 @@ func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int,
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay()
retryAfter = &cooldown
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
}
return errCode, retryAfter
}
@@ -202,7 +203,7 @@ func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -217,7 +218,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
@@ -228,8 +229,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -247,8 +248,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -261,7 +262,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -273,10 +274,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
@@ -284,23 +285,23 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Errorf("qwen executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
@@ -320,7 +321,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
@@ -331,8 +332,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
@@ -357,8 +358,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -371,7 +372,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
@@ -383,19 +384,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
@@ -415,9 +416,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
@@ -429,8 +430,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
@@ -449,17 +450,17 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
modelName = baseModel
}
enc, err := tokenizerForModel(modelName)
enc, err := helps.TokenizerForModel(modelName)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, body)
count, err := helps.CountOpenAIChatTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
usageJSON := helps.BuildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translated}, nil
}
@@ -508,16 +509,15 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
r.Header["X-DashScope-UserAgent"] = []string{qwenUserAgent}
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
r.Header.Set("Sec-Fetch-Mode", "cors")
r.Header.Set("X-Stainless-Lang", "js")
r.Header.Set("X-Stainless-Arch", "arm64")
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
r.Header["X-DashScope-CacheControl"] = []string{"enable"}
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("X-Stainless-Os", "MacOS")
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
r.Header["X-DashScope-AuthType"] = []string{"qwen-oauth"}
r.Header.Set("X-Stainless-Runtime", "node")
if stream {

View File

@@ -330,32 +330,45 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
// Reorder parts for 'model' role to ensure thinking block is first
// Reorder parts for 'model' role:
// 1. Thinking parts first (Antigravity API requirement)
// 2. Regular parts (text, inlineData, etc.)
// 3. FunctionCall parts last
//
// Moving functionCall parts to the end prevents tool_use↔tool_result
// pairing breakage: the Antigravity API internally splits model messages
// at functionCall boundaries. If a text part follows a functionCall, the
// split creates an extra assistant turn between tool_use and tool_result,
// which Claude rejects with "tool_use ids were found without tool_result
// blocks immediately after".
if role == "model" {
partsResult := gjson.GetBytes(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
var otherParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else {
otherParts = append(otherParts, part)
}
}
if len(thinkingParts) > 0 {
firstPartIsThinking := parts[0].Get("thought").Bool()
if !firstPartIsThinking || len(thinkingParts) > 1 {
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
if len(parts) > 1 {
var thinkingParts []gjson.Result
var regularParts []gjson.Result
var functionCallParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else if part.Get("functionCall").Exists() {
functionCallParts = append(functionCallParts, part)
} else {
regularParts = append(regularParts, part)
}
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
}
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
}
for _, p := range regularParts {
newParts = append(newParts, p.Value())
}
for _, p := range functionCallParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
}
}
}

View File

@@ -361,6 +361,167 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
}
}
func TestConvertClaudeRequestToAntigravity_ReorderTextAfterFunctionCall(t *testing.T) {
// Bug: text part after tool_use in an assistant message causes Antigravity
// to split at functionCall boundary, creating an extra assistant turn that
// breaks tool_use↔tool_result adjacency (upstream issue #989).
// Fix: reorder parts so functionCall comes last.
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me check..."},
{
"type": "tool_use",
"id": "call_abc",
"name": "Read",
"input": {"file": "test.go"}
},
{"type": "text", "text": "Reading the file now"}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_abc",
"content": "file content"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 3 {
t.Fatalf("Expected 3 parts, got %d", len(parts))
}
// Text parts should come before functionCall
if parts[0].Get("text").String() != "Let me check..." {
t.Errorf("Expected first text part first, got %s", parts[0].Raw)
}
if parts[1].Get("text").String() != "Reading the file now" {
t.Errorf("Expected second text part second, got %s", parts[1].Raw)
}
if !parts[2].Get("functionCall").Exists() {
t.Errorf("Expected functionCall last, got %s", parts[2].Raw)
}
if parts[2].Get("functionCall.name").String() != "Read" {
t.Errorf("Expected functionCall name 'Read', got '%s'", parts[2].Get("functionCall.name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ReorderParallelFunctionCalls(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Reading both files."},
{
"type": "tool_use",
"id": "call_1",
"name": "Read",
"input": {"file": "a.go"}
},
{"type": "text", "text": "And this one too."},
{
"type": "tool_use",
"id": "call_2",
"name": "Read",
"input": {"file": "b.go"}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 4 {
t.Fatalf("Expected 4 parts, got %d", len(parts))
}
if parts[0].Get("text").String() != "Reading both files." {
t.Errorf("Expected first text, got %s", parts[0].Raw)
}
if parts[1].Get("text").String() != "And this one too." {
t.Errorf("Expected second text, got %s", parts[1].Raw)
}
if parts[2].Get("functionCall.name").String() != "Read" || parts[2].Get("functionCall.id").String() != "call_1" {
t.Errorf("Expected fc1 third, got %s", parts[2].Raw)
}
if parts[3].Get("functionCall.name").String() != "Read" || parts[3].Get("functionCall.id").String() != "call_2" {
t.Errorf("Expected fc2 fourth, got %s", parts[3].Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ReorderThinkingAndTextBeforeFunctionCall(t *testing.T) {
cache.ClearSignatureCache("")
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think about this..."
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Before thinking"},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_xyz",
"name": "Bash",
"input": {"command": "ls"}
},
{"type": "text", "text": "After tool call"}
]
}
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// contents.1 = assistant message (contents.0 = user)
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
if len(parts) != 4 {
t.Fatalf("Expected 4 parts, got %d", len(parts))
}
// Order: thinking → text → text → functionCall
if !parts[0].Get("thought").Bool() {
t.Error("First part should be thinking")
}
if parts[1].Get("functionCall").Exists() || parts[1].Get("thought").Bool() {
t.Errorf("Second part should be text, got %s", parts[1].Raw)
}
if parts[2].Get("functionCall").Exists() || parts[2].Get("thought").Bool() {
t.Errorf("Third part should be text, got %s", parts[2].Raw)
}
if !parts[3].Get("functionCall").Exists() {
t.Errorf("Last part should be functionCall, got %s", parts[3].Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",

View File

@@ -284,12 +284,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
}
// Process the output array for content and function calls
var toolCalls [][]byte
outputResult := responseResult.Get("output")
if outputResult.IsArray() {
outputArray := outputResult.Array()
var contentText string
var reasoningText string
var toolCalls [][]byte
for _, outputItem := range outputArray {
outputType := outputItem.Get("type").String()
@@ -367,8 +367,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if statusResult := responseResult.Get("status"); statusResult.Exists() {
status := statusResult.String()
if status == "completed" {
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
}
}

View File

@@ -201,6 +201,7 @@ var zhStrings = map[string]string{
"usage_output": "输出",
"usage_cached": "缓存",
"usage_reasoning": "思考",
"usage_time": "时间",
// ── Logs ──
"logs_title": "📋 日志",
@@ -352,6 +353,7 @@ var enStrings = map[string]string{
"usage_output": "Output",
"usage_cached": "Cached",
"usage_reasoning": "Reasoning",
"usage_time": "Time",
// ── Logs ──
"logs_title": "📋 Logs",

View File

@@ -248,6 +248,9 @@ func (m usageTabModel) renderContent() string {
// Token type breakdown from details
sb.WriteString(m.renderTokenBreakdown(stats))
// Latency breakdown from details
sb.WriteString(m.renderLatencyBreakdown(stats))
}
}
}
@@ -308,6 +311,57 @@ func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var totalLatency int64
var count int
var minLatency, maxLatency int64
first := true
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
latencyMs := int64(getFloat(dm, "latency_ms"))
if latencyMs <= 0 {
continue
}
totalLatency += latencyMs
count++
if first {
minLatency = latencyMs
maxLatency = latencyMs
first = false
} else {
if latencyMs < minLatency {
minLatency = latencyMs
}
if latencyMs > maxLatency {
maxLatency = latencyMs
}
}
}
if count == 0 {
return ""
}
avgLatency := totalLatency / int64(count)
return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
avgLatency, minLatency, maxLatency)
}
// renderBarChart renders a simple ASCII horizontal bar chart.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
if maxBarWidth < 10 {

View File

@@ -0,0 +1,134 @@
package tui
import (
"strings"
"testing"
)
func TestRenderLatencyBreakdown(t *testing.T) {
tests := []struct {
name string
modelStats map[string]any
wantEmpty bool
wantContains string
}{
{
name: "no details",
modelStats: map[string]any{},
wantEmpty: true,
},
{
name: "empty details",
modelStats: map[string]any{
"details": []any{},
},
wantEmpty: true,
},
{
name: "details with zero latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(0),
},
},
},
wantEmpty: true,
},
{
name: "single request with latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(1500),
},
},
},
wantEmpty: false,
wantContains: "avg 1500ms min 1500ms max 1500ms",
},
{
name: "multiple requests with varying latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(100),
},
map[string]any{
"latency_ms": float64(200),
},
map[string]any{
"latency_ms": float64(300),
},
},
},
wantEmpty: false,
wantContains: "avg 200ms min 100ms max 300ms",
},
{
name: "mixed valid and invalid latency values",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(500),
},
map[string]any{
"latency_ms": float64(0),
},
map[string]any{
"latency_ms": float64(1500),
},
},
},
wantEmpty: false,
wantContains: "avg 1000ms min 500ms max 1500ms",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := usageTabModel{}
result := m.renderLatencyBreakdown(tt.modelStats)
if tt.wantEmpty {
if result != "" {
t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
}
return
}
if result == "" {
t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
return
}
if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
}
})
}
}
func TestUsageTimeTranslations(t *testing.T) {
prevLocale := CurrentLocale()
t.Cleanup(func() {
SetLocale(prevLocale)
})
tests := []struct {
locale string
want string
}{
{locale: "en", want: "Time"},
{locale: "zh", want: "时间"},
}
for _, tt := range tests {
t.Run(tt.locale, func(t *testing.T) {
SetLocale(tt.locale)
if got := T("usage_time"); got != tt.want {
t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -80,6 +80,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
}
if oldCfg.QuotaExceeded.AntigravityCredits != newCfg.QuotaExceeded.AntigravityCredits {
changes = append(changes, fmt.Sprintf("quota-exceeded.antigravity-credits: %t -> %t", oldCfg.QuotaExceeded.AntigravityCredits, newCfg.QuotaExceeded.AntigravityCredits))
}
if oldCfg.Routing.Strategy != newCfg.Routing.Strategy {
changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy))

View File

@@ -229,7 +229,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
MaxRetryCredentials: 1,
MaxRetryInterval: 1,
WebsocketAuth: false,
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false},
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
CodexKey: []config.CodexKey{{APIKey: "x1"}},
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
@@ -253,7 +253,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
MaxRetryCredentials: 3,
MaxRetryInterval: 3,
WebsocketAuth: true,
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true, AntigravityCredits: true},
ClaudeKey: []config.ClaudeKey{
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
{APIKey: "c2"},
@@ -297,6 +297,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
expectContains(t, details, "nonstream-keepalive-interval: 0 -> 5")
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
expectContains(t, details, "quota-exceeded.antigravity-credits: false -> true")
expectContains(t, details, "api-keys count: 1 -> 2")
expectContains(t, details, "claude-api-key count: 1 -> 2")
expectContains(t, details, "codex-api-key count: 1 -> 2")
@@ -320,7 +321,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
MaxRetryCredentials: 1,
MaxRetryInterval: 1,
WebsocketAuth: false,
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false},
GeminiKey: []config.GeminiKey{
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
},
@@ -374,7 +375,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
MaxRetryCredentials: 3,
MaxRetryInterval: 3,
WebsocketAuth: true,
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true, AntigravityCredits: true},
GeminiKey: []config.GeminiKey{
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
},
@@ -437,6 +438,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
expectContains(t, changes, "ws-auth: false -> true")
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
expectContains(t, changes, "quota-exceeded.antigravity-credits: false -> true")
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")

View File

@@ -1,6 +1,9 @@
package openai
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
const (
openAIChatEndpoint = "/chat/completions"
@@ -12,6 +15,12 @@ func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool)
return "", false
}
info := registry.GetGlobalRegistry().GetModelInfo(modelName, "")
if info == nil {
baseModel := thinking.ParseSuffix(modelName).ModelName
if baseModel != "" && baseModel != modelName {
info = registry.GetGlobalRegistry().GetModelInfo(baseModel, "")
}
}
if info == nil || len(info.SupportedEndpoints) == 0 {
return "", false
}

View File

@@ -0,0 +1,29 @@
package openai
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func TestResolveEndpointOverride_StripsThinkingSuffix(t *testing.T) {
const clientID = "test-endpoint-compat-suffix"
reg := registry.GetGlobalRegistry()
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
{
ID: "test-gemini-chat-only",
SupportedEndpoints: []string{openAIChatEndpoint},
},
})
t.Cleanup(func() {
reg.UnregisterClient(clientID)
})
override, ok := resolveEndpointOverride("test-gemini-chat-only(high)", openAIResponsesEndpoint)
if !ok {
t.Fatalf("expected endpoint override to be resolved")
}
if override != openAIChatEndpoint {
t.Fatalf("override endpoint = %q, want %q", override, openAIChatEndpoint)
}
}

View File

@@ -10,6 +10,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
@@ -22,6 +23,25 @@ import (
"github.com/tidwall/sjson"
)
func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
if w == nil || len(chunk) == 0 {
return
}
if _, err := w.Write(chunk); err != nil {
return
}
if bytes.HasSuffix(chunk, []byte("\n\n")) {
return
}
suffix := []byte("\n\n")
if bytes.HasSuffix(chunk, []byte("\n")) {
suffix = []byte("\n")
}
if _, err := w.Write(suffix); err != nil {
return
}
}
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIResponsesAPIHandler struct {
@@ -271,11 +291,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk logic (matching forwardResponsesStream)
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
writeResponsesSSEChunk(c.Writer, chunk)
flusher.Flush()
// Continue
@@ -400,11 +416,7 @@ func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context,
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
writeResponsesSSEChunk(c.Writer, chunk)
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {

View File

@@ -0,0 +1,52 @@
package openai
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
t.Fatalf("expected gin writer to implement http.Flusher")
}
data := make(chan []byte, 2)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}")
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
body := recorder.Body.String()
parts := strings.Split(strings.TrimSpace(body), "\n\n")
if len(parts) != 2 {
t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), body)
}
expectedPart1 := "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}"
if parts[0] != expectedPart1 {
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
}
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
if parts[1] != expectedPart2 {
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
}
}

View File

@@ -33,9 +33,6 @@ const (
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
wsBodyLogMaxSize = 64 * 1024
wsBodyLogTruncated = "\n[websocket log truncated]\n"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -55,14 +52,14 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
return
}
passthroughSessionID := uuid.NewString()
clientRemoteAddr := ""
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
}
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
retainResponsesWebsocketToolCaches(downstreamSessionKey)
clientIP := websocketClientAddress(c)
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
var wsTerminateErr error
var wsBodyLog strings.Builder
defer func() {
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
if wsTerminateErr != nil {
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {
@@ -167,6 +164,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -203,6 +203,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
}
func websocketClientAddress(c *gin.Context) string {
if c == nil || c.Request == nil {
return ""
}
return strings.TrimSpace(c.ClientIP())
}
func websocketUpgradeHeaders(req *http.Request) http.Header {
headers := http.Header{}
if req == nil {
@@ -277,6 +284,15 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}
// Compaction can cause clients to replace local websocket history with a new
// compact transcript on the next `response.create`. When the input already
// contains historical model output items, treating it as an incremental append
// duplicates stale turn-state and can leave late orphaned function_call items.
if shouldReplaceWebsocketTranscript(rawJSON, nextInput) {
normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest)
return normalized, bytes.Clone(normalized), nil
}
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
// Do not expand it into a full input transcript; upstream expects the incremental payload.
if allowIncrementalInputWithPreviousResponseID {
@@ -318,6 +334,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
if errDedupeFunctionCalls == nil {
mergedInput = dedupedInput
}
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
@@ -348,6 +368,91 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
return normalized, bytes.Clone(normalized), nil
}
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend {
return false
}
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
return false
}
if !nextInput.Exists() || !nextInput.IsArray() {
return false
}
for _, item := range nextInput.Array() {
switch strings.TrimSpace(item.Get("type").String()) {
case "function_call":
return true
case "message":
role := strings.TrimSpace(item.Get("role").String())
if role == "assistant" {
return true
}
}
}
return false
}
func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte {
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
if !gjson.GetBytes(normalized, "model").Exists() {
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if modelName != "" {
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
}
if !gjson.GetBytes(normalized, "instructions").Exists() {
instructions := gjson.GetBytes(lastRequest, "instructions")
if instructions.Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
}
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return bytes.Clone(normalized)
}
func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
rawArray = strings.TrimSpace(rawArray)
if rawArray == "" {
return "[]", nil
}
var items []json.RawMessage
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
return "", errUnmarshal
}
seenCallIDs := make(map[string]struct{}, len(items))
filtered := make([]json.RawMessage, 0, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call" {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID != "" {
if _, ok := seenCallIDs[callID]; ok {
continue
}
seenCallIDs[callID] = struct{}{}
}
}
filtered = append(filtered, item)
}
out, errMarshal := json.Marshal(filtered)
if errMarshal != nil {
return "", errMarshal
}
return string(out), nil
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
@@ -613,6 +718,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
) ([]byte, error) {
completed := false
completedOutput := []byte("[]")
downstreamSessionKey := ""
if c != nil && c.Request != nil {
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
}
for {
select {
@@ -690,6 +799,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
payloads := websocketJSONPayloadsFromChunk(chunk)
for i := range payloads {
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
completed = true
@@ -837,71 +947,18 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
if builder == nil {
return
}
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
if !appendWebsocketLogString(builder, "\n") {
return
}
builder.WriteString("\n")
}
if !appendWebsocketLogString(builder, "websocket.") {
return
}
if !appendWebsocketLogString(builder, eventType) {
return
}
if !appendWebsocketLogString(builder, "\n") {
return
}
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
appendWebsocketLogString(builder, wsBodyLogTruncated)
return
}
appendWebsocketLogString(builder, "\n")
}
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.WriteString(value)
return true
}
builder.WriteString(value[:remaining])
return false
}
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.Write(value)
return true
}
limit := remaining - reserveForSuffix
if limit < 0 {
limit = 0
}
if limit > len(value) {
limit = len(value)
}
builder.Write(value[:limit])
return false
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
}
func websocketPayloadEventType(payload []byte) string {
@@ -917,15 +974,8 @@ func websocketPayloadPreview(payload []byte) string {
if len(trimmedPayload) == 0 {
return "<empty>"
}
preview := trimmedPayload
if len(preview) > wsPayloadLogMaxSize {
preview = preview[:wsPayloadLogMaxSize]
}
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
previewText := strings.ReplaceAll(string(trimmedPayload), "\n", "\\n")
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
if len(trimmedPayload) > wsPayloadLogMaxSize {
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
}
return previewText
}

View File

@@ -10,6 +10,7 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -27,6 +28,12 @@ type websocketCaptureExecutor struct {
payloads [][]byte
}
type websocketCompactionCaptureExecutor struct {
mu sync.Mutex
streamPayloads [][]byte
compactPayload []byte
}
type orderedWebsocketSelector struct {
mu sync.Mutex
order []string
@@ -126,6 +133,52 @@ func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth,
return nil, errors.New("not implemented")
}
func (e *websocketCompactionCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCompactionCaptureExecutor) Execute(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
e.mu.Lock()
e.compactPayload = bytes.Clone(req.Payload)
e.mu.Unlock()
if opts.Alt != "responses/compact" {
return coreexecutor.Response{}, fmt.Errorf("unexpected non-compact execute alt: %q", opts.Alt)
}
return coreexecutor.Response{Payload: []byte(`{"id":"cmp-1","object":"response.compaction"}`)}, nil
}
func (e *websocketCompactionCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.mu.Lock()
callIndex := len(e.streamPayloads)
e.streamPayloads = append(e.streamPayloads, bytes.Clone(req.Payload))
e.mu.Unlock()
var payload []byte
switch callIndex {
case 0:
payload = []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}]}}`)
case 1:
payload = []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[{"type":"message","id":"assistant-1"}]}}`)
default:
payload = []byte(`{"type":"response.completed","response":{"id":"resp-3","output":[{"type":"message","id":"assistant-2"}]}}`)
}
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Payload: payload}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketCompactionCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketCompactionCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCompactionCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
@@ -339,33 +392,6 @@ func TestAppendWebsocketEvent(t *testing.T) {
}
}
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
var builder strings.Builder
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
appendWebsocketEvent(&builder, "request", payload)
got := builder.String()
if len(got) > wsBodyLogMaxSize {
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
}
if !strings.Contains(got, wsBodyLogTruncated) {
t.Fatalf("expected truncation marker in body log")
}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
@@ -390,6 +416,108 @@ func TestSetWebsocketRequestBody(t *testing.T) {
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`)
warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm)
if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" {
t.Fatalf("expected warmup output to remain")
}
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3", len(input))
}
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected first item: %s", input[0].Raw)
}
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted output: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 1 {
t.Fatalf("repaired input len = %d, want 1", len(input))
}
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
callCache.record(sessionKey, "call-1", []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`))
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3", len(input))
}
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted call: %s", input[0].Raw)
}
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected output item: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 1 {
t.Fatalf("repaired input len = %d, want 1", len(input))
}
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
}
}
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`)
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
cached, ok := cache.get(sessionKey, "call-1")
if !ok {
t.Fatalf("expected cached tool call")
}
if gjson.GetBytes(cached, "type").String() != "function_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
t.Fatalf("unexpected cached tool call: %s", cached)
}
}
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -593,6 +721,31 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
}
}
func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, engine := gin.CreateTestContext(recorder)
if err := engine.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}); err != nil {
t.Fatalf("SetTrustedProxies: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/v1/responses/ws", nil)
req.RemoteAddr = "172.18.0.1:34282"
req.Header.Set("X-Forwarded-For", "203.0.113.7")
c.Request = req
if got := websocketClientAddress(c); got != strings.TrimSpace(c.ClientIP()) {
t.Fatalf("websocketClientAddress = %q, ClientIP = %q", got, c.ClientIP())
}
}
func TestWebsocketClientAddressReturnsEmptyForNilContext(t *testing.T) {
if got := websocketClientAddress(nil); got != "" {
t.Fatalf("websocketClientAddress(nil) = %q, want empty", got)
}
}
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -662,3 +815,183 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
}
}
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1","role":"assistant"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
t.Fatalf("previous_response_id must not exist in transcript replacement mode")
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 2 {
t.Fatalf("replacement input len = %d, want 2: %s", len(items), normalized)
}
if items[0].Get("id").String() != "fc-compact" || items[1].Get("id").String() != "msg-2" {
t.Fatalf("replacement transcript was not preserved as-is: %s", normalized)
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match replacement request")
}
}
func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplacement(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1","role":"assistant"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"dev-1","role":"developer"},{"type":"message","id":"msg-2"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 4 {
t.Fatalf("merged input len = %d, want 4: %s", len(items), normalized)
}
if items[0].Get("id").String() != "msg-1" ||
items[1].Get("id").String() != "assistant-1" ||
items[2].Get("id").String() != "dev-1" ||
items[3].Get("id").String() != "msg-2" {
t.Fatalf("developer follow-up should preserve merge behavior: %s", normalized)
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match merged request")
}
}
func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 3 {
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
}
if items[0].Get("id").String() != "fc-1" ||
items[1].Get("id").String() != "tool-out-1" ||
items[2].Get("id").String() != "msg-2" {
t.Fatalf("unexpected merged input order: %s", normalized)
}
}
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketCompactionCaptureExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
router.POST("/v1/responses/compact", h.Compact)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
if errClose := conn.Close(); errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
requests := []string{
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`,
}
for i := range requests {
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
}
}
compactResp, errPost := server.Client().Post(
server.URL+"/v1/responses/compact",
"application/json",
strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`),
)
if errPost != nil {
t.Fatalf("compact request failed: %v", errPost)
}
if errClose := compactResp.Body.Close(); errClose != nil {
t.Fatalf("close compact response body: %v", errClose)
}
if compactResp.StatusCode != http.StatusOK {
t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK)
}
// Simulate a post-compaction client turn that replaces local history with a compacted transcript.
// The websocket handler must treat this as a state reset, not append it to stale pre-compaction state.
postCompact := `{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil {
t.Fatalf("write post-compact websocket message: %v", errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read post-compact websocket message: %v", errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted)
}
executor.mu.Lock()
defer executor.mu.Unlock()
if executor.compactPayload == nil {
t.Fatalf("compact payload was not captured")
}
if len(executor.streamPayloads) != 3 {
t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads))
}
merged := executor.streamPayloads[2]
items := gjson.GetBytes(merged, "input").Array()
if len(items) != 2 {
t.Fatalf("merged input len = %d, want 2: %s", len(items), merged)
}
if items[0].Get("id").String() != "fc-compact" ||
items[1].Get("id").String() != "msg-2" {
t.Fatalf("unexpected post-compact input order: %s", merged)
}
if items[0].Get("call_id").String() != "call-1" {
t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String())
}
}

View File

@@ -0,0 +1,402 @@
package openai
import (
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
websocketToolOutputCacheMaxPerSession = 256
websocketToolOutputCacheTTL = 30 * time.Minute
)
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolSessionRefs = newWebsocketToolSessionRefCounter()
type websocketToolOutputCache struct {
mu sync.Mutex
ttl time.Duration
maxPerSession int
sessions map[string]*websocketToolOutputSession
}
type websocketToolOutputSession struct {
lastSeen time.Time
outputs map[string]json.RawMessage
order []string
}
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
if ttl < 0 {
ttl = websocketToolOutputCacheTTL
}
if maxPerSession <= 0 {
maxPerSession = websocketToolOutputCacheMaxPerSession
}
return &websocketToolOutputCache{
ttl: ttl,
maxPerSession: maxPerSession,
sessions: make(map[string]*websocketToolOutputSession),
}
}
func (c *websocketToolOutputCache) record(sessionKey string, callID string, item json.RawMessage) {
sessionKey = strings.TrimSpace(sessionKey)
callID = strings.TrimSpace(callID)
if sessionKey == "" || callID == "" || c == nil {
return
}
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupLocked(now)
session, ok := c.sessions[sessionKey]
if !ok || session == nil {
session = &websocketToolOutputSession{
lastSeen: now,
outputs: make(map[string]json.RawMessage),
}
c.sessions[sessionKey] = session
}
session.lastSeen = now
if _, exists := session.outputs[callID]; !exists {
session.order = append(session.order, callID)
}
session.outputs[callID] = append(json.RawMessage(nil), item...)
for len(session.order) > c.maxPerSession {
evict := session.order[0]
session.order = session.order[1:]
delete(session.outputs, evict)
}
}
func (c *websocketToolOutputCache) get(sessionKey string, callID string) (json.RawMessage, bool) {
sessionKey = strings.TrimSpace(sessionKey)
callID = strings.TrimSpace(callID)
if sessionKey == "" || callID == "" || c == nil {
return nil, false
}
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupLocked(now)
session, ok := c.sessions[sessionKey]
if !ok || session == nil {
return nil, false
}
session.lastSeen = now
item, ok := session.outputs[callID]
if !ok || len(item) == 0 {
return nil, false
}
return append(json.RawMessage(nil), item...), true
}
func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
if c == nil || c.ttl <= 0 {
return
}
for key, session := range c.sessions {
if session == nil {
delete(c.sessions, key)
continue
}
if now.Sub(session.lastSeen) > c.ttl {
delete(c.sessions, key)
}
}
}
func (c *websocketToolOutputCache) deleteSession(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
delete(c.sessions, sessionKey)
}
func websocketDownstreamSessionKey(req *http.Request) string {
if req == nil {
return ""
}
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
return requestID
}
if raw := strings.TrimSpace(req.Header.Get("X-Codex-Turn-Metadata")); raw != "" {
if sessionID := strings.TrimSpace(gjson.Get(raw, "session_id").String()); sessionID != "" {
return sessionID
}
}
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
return sessionID
}
return ""
}
type websocketToolSessionRefCounter struct {
mu sync.Mutex
counts map[string]int
}
func newWebsocketToolSessionRefCounter() *websocketToolSessionRefCounter {
return &websocketToolSessionRefCounter{counts: make(map[string]int)}
}
func (c *websocketToolSessionRefCounter) acquire(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.counts[sessionKey]++
}
func (c *websocketToolSessionRefCounter) release(sessionKey string) bool {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
count := c.counts[sessionKey]
if count <= 1 {
delete(c.counts, sessionKey)
return true
}
c.counts[sessionKey] = count - 1
return false
}
func retainResponsesWebsocketToolCaches(sessionKey string) {
if defaultWebsocketToolSessionRefs == nil {
return
}
defaultWebsocketToolSessionRefs.acquire(sessionKey)
}
func releaseResponsesWebsocketToolCaches(sessionKey string) {
if defaultWebsocketToolSessionRefs == nil {
return
}
if !defaultWebsocketToolSessionRefs.release(sessionKey) {
return
}
if defaultWebsocketToolOutputCache != nil {
defaultWebsocketToolOutputCache.deleteSession(sessionKey)
}
if defaultWebsocketToolCallCache != nil {
defaultWebsocketToolCallCache.deleteSession(sessionKey)
}
}
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
}
func repairResponsesWebsocketToolCallsWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
return repairResponsesWebsocketToolCallsWithCaches(cache, nil, sessionKey, payload)
}
func repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || outputCache == nil || len(payload) == 0 {
return payload
}
input := gjson.GetBytes(payload, "input")
if !input.Exists() || !input.IsArray() {
return payload
}
allowOrphanOutputs := strings.TrimSpace(gjson.GetBytes(payload, "previous_response_id").String()) != ""
updatedRaw, errRepair := repairResponsesToolCallsArray(outputCache, callCache, sessionKey, input.Raw, allowOrphanOutputs)
if errRepair != nil || updatedRaw == "" || updatedRaw == input.Raw {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, "input", []byte(updatedRaw))
if errSet != nil {
return payload
}
return updated
}
func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCache, sessionKey string, rawArray string, allowOrphanOutputs bool) (string, error) {
rawArray = strings.TrimSpace(rawArray)
if rawArray == "" {
return "[]", nil
}
var items []json.RawMessage
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
return "", errUnmarshal
}
// First pass: record tool outputs and remember which call_ids have outputs in this payload.
outputPresent := make(map[string]struct{}, len(items))
callPresent := make(map[string]struct{}, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
switch itemType {
case "function_call_output":
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
}
outputPresent[callID] = struct{}{}
outputCache.record(sessionKey, callID, item)
case "function_call":
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
}
callPresent[callID] = struct{}{}
if callCache != nil {
callCache.record(sessionKey, callID, item)
}
}
}
filtered := make([]json.RawMessage, 0, len(items))
insertedCalls := make(map[string]struct{}, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call_output" {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
// Upstream rejects tool outputs without a call_id; drop it.
continue
}
if allowOrphanOutputs {
filtered = append(filtered, item)
continue
}
if _, ok := callPresent[callID]; ok {
filtered = append(filtered, item)
continue
}
if callCache != nil {
if cached, ok := callCache.get(sessionKey, callID); ok {
if _, already := insertedCalls[callID]; !already {
filtered = append(filtered, cached)
insertedCalls[callID] = struct{}{}
callPresent[callID] = struct{}{}
}
filtered = append(filtered, item)
continue
}
}
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
continue
}
if itemType != "function_call" {
filtered = append(filtered, item)
continue
}
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
// Upstream rejects tool calls without a call_id; drop it.
continue
}
if _, ok := outputPresent[callID]; ok {
filtered = append(filtered, item)
continue
}
if cached, ok := outputCache.get(sessionKey, callID); ok {
filtered = append(filtered, item)
filtered = append(filtered, cached)
outputPresent[callID] = struct{}{}
continue
}
// Drop orphaned function_call items; upstream rejects transcripts with missing outputs.
}
out, errMarshal := json.Marshal(filtered)
if errMarshal != nil {
return "", errMarshal
}
return string(out), nil
}
func recordResponsesWebsocketToolCallsFromPayload(sessionKey string, payload []byte) {
recordResponsesWebsocketToolCallsFromPayloadWithCache(defaultWebsocketToolCallCache, sessionKey, payload)
}
func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || cache == nil || len(payload) == 0 {
return
}
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
switch eventType {
case "response.completed":
output := gjson.GetBytes(payload, "response.output")
if !output.Exists() || !output.IsArray() {
return
}
for _, item := range output.Array() {
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
continue
}
callID := strings.TrimSpace(item.Get("call_id").String())
if callID == "" {
continue
}
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
}
case "response.output_item.added", "response.output_item.done":
item := gjson.GetBytes(payload, "item")
if !item.Exists() || !item.IsObject() {
return
}
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
return
}
callID := strings.TrimSpace(item.Get("call_id").String())
if callID == "" {
return
}
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
}
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
@@ -437,6 +438,31 @@ func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []stri
return []string{resolved}
}
func (m *Manager) selectionModelForAuth(auth *Auth, routeModel string) string {
requestedModel := rewriteModelForAuth(routeModel, auth)
if strings.TrimSpace(requestedModel) == "" {
requestedModel = strings.TrimSpace(routeModel)
}
resolvedModel := m.applyOAuthModelAlias(auth, requestedModel)
if strings.TrimSpace(resolvedModel) == "" {
resolvedModel = requestedModel
}
return resolvedModel
}
func (m *Manager) selectionModelKeyForAuth(auth *Auth, routeModel string) string {
return canonicalModelKey(m.selectionModelForAuth(auth, routeModel))
}
func (m *Manager) stateModelForExecution(auth *Auth, routeModel, upstreamModel string, pooled bool) string {
stateModel := executionResultModel(routeModel, upstreamModel, pooled)
selectionModel := m.selectionModelForAuth(auth, routeModel)
if canonicalModelKey(selectionModel) == canonicalModelKey(upstreamModel) && strings.TrimSpace(selectionModel) != "" {
return strings.TrimSpace(upstreamModel)
}
return stateModel
}
func executionResultModel(routeModel, upstreamModel string, pooled bool) string {
if pooled {
if resolved := strings.TrimSpace(upstreamModel); resolved != "" {
@@ -449,14 +475,14 @@ func executionResultModel(routeModel, upstreamModel string, pooled bool) string
return strings.TrimSpace(upstreamModel)
}
func filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
func (m *Manager) filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
if len(candidates) == 0 {
return nil
}
now := time.Now()
out := make([]string, 0, len(candidates))
for _, upstreamModel := range candidates {
stateModel := executionResultModel(routeModel, upstreamModel, pooled)
stateModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
blocked, _, _ := isAuthBlockedForModel(auth, stateModel, now)
if blocked {
continue
@@ -469,7 +495,7 @@ func filterExecutionModels(auth *Auth, routeModel string, candidates []string, p
func (m *Manager) preparedExecutionModels(auth *Auth, routeModel string) ([]string, bool) {
candidates := m.executionModelCandidates(auth, routeModel)
pooled := len(candidates) > 1
return filterExecutionModels(auth, routeModel, candidates, pooled), pooled
return m.filterExecutionModels(auth, routeModel, candidates, pooled), pooled
}
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
@@ -477,6 +503,83 @@ func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string
return models
}
func (m *Manager) availableAuthsForRouteModel(auths []*Auth, provider, routeModel string, now time.Time) ([]*Auth, error) {
if len(auths) == 0 {
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
}
availableByPriority := make(map[int][]*Auth)
cooldownCount := 0
var earliest time.Time
for _, candidate := range auths {
checkModel := m.selectionModelForAuth(candidate, routeModel)
blocked, reason, next := isAuthBlockedForModel(candidate, checkModel, now)
if !blocked {
priority := authPriority(candidate)
availableByPriority[priority] = append(availableByPriority[priority], candidate)
continue
}
if reason == blockReasonCooldown {
cooldownCount++
if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) {
earliest = next
}
}
}
if len(availableByPriority) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return nil, newModelCooldownError(routeModel, providerForError, resetIn)
}
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
}
bestPriority := 0
found := false
for priority := range availableByPriority {
if !found || priority > bestPriority {
bestPriority = priority
found = true
}
}
available := availableByPriority[bestPriority]
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, nil
}
func selectionArgForSelector(selector Selector, routeModel string) string {
if isBuiltInSelector(selector) {
return ""
}
return routeModel
}
func (m *Manager) authSupportsRouteModel(registryRef *registry.ModelRegistry, auth *Auth, routeModel string) bool {
if registryRef == nil || auth == nil {
return true
}
routeKey := canonicalModelKey(routeModel)
if routeKey == "" {
return true
}
if registryRef.ClientSupportsModel(auth.ID, routeKey) {
return true
}
selectionKey := m.selectionModelKeyForAuth(auth, routeModel)
return selectionKey != "" && selectionKey != routeKey && registryRef.ClientSupportsModel(auth.ID, selectionKey)
}
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
if ch == nil {
return
@@ -627,7 +730,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
}
var lastErr error
for idx, execModel := range execModels {
resultModel := executionResultModel(routeModel, execModel, pooled)
resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled)
execReq := req
execReq.Model = execModel
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
@@ -1107,7 +1210,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
@@ -1185,7 +1288,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
@@ -1734,77 +1837,79 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
}
} else {
if result.Model != "" {
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
state.UpdatedAt = now
if result.Error != nil {
state.LastError = cloneError(result.Error)
state.StatusMessage = result.Error.Message
auth.LastError = cloneError(result.Error)
auth.StatusMessage = result.Error.Message
}
if !isRequestScopedNotFoundResultError(result.Error) {
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
state.UpdatedAt = now
if result.Error != nil {
state.LastError = cloneError(result.Error)
state.StatusMessage = result.Error.Message
auth.LastError = cloneError(result.Error)
auth.StatusMessage = result.Error.Message
}
statusCode := statusCodeFromResult(result.Error)
if isModelSupportResultError(result.Error) {
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "model_not_supported"
shouldSuspendModel = true
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
case 404:
statusCode := statusCodeFromResult(result.Error)
if isModelSupportResultError(result.Error) {
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
suspendReason = "model_not_supported"
shouldSuspendModel = true
case 429:
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
Exceeded: true,
Reason: "quota",
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
case 404:
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
case 429:
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
Exceeded: true,
Reason: "quota",
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
state.NextRetryAfter = next
}
default:
state.NextRetryAfter = time.Time{}
}
default:
state.NextRetryAfter = time.Time{}
}
}
auth.Status = StatusError
auth.UpdatedAt = now
updateAggregatedAvailability(auth, now)
auth.Status = StatusError
auth.UpdatedAt = now
updateAggregatedAvailability(auth, now)
}
} else {
applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
}
@@ -2056,11 +2161,29 @@ func isModelSupportResultError(err *Error) bool {
return isModelSupportErrorMessage(err.Message)
}
func isRequestScopedNotFoundMessage(message string) bool {
if message == "" {
return false
}
lower := strings.ToLower(message)
return strings.Contains(lower, "item with id") &&
strings.Contains(lower, "not found") &&
strings.Contains(lower, "items are not persisted when `store` is set to false")
}
func isRequestScopedNotFoundResultError(err *Error) bool {
if err == nil || statusCodeFromResult(err) != http.StatusNotFound {
return false
}
return isRequestScopedNotFoundMessage(err.Message)
}
// isRequestInvalidError returns true if the error represents a client request
// error that should not be retried. Specifically, it treats 400 responses with
// "invalid_request_error" and all 422 responses as request-shape failures,
// where switching auths or pooled upstream models will not help. Model-support
// errors are excluded so routing can fall through to another auth or upstream.
// "invalid_request_error", request-scoped 404 item misses caused by `store=false`,
// and all 422 responses as request-shape failures, where switching auths or
// pooled upstream models will not help. Model-support errors are excluded so
// routing can fall through to another auth or upstream.
func isRequestInvalidError(err error) bool {
if err == nil {
return false
@@ -2072,6 +2195,8 @@ func isRequestInvalidError(err error) bool {
switch status {
case http.StatusBadRequest:
return strings.Contains(err.Error(), "invalid_request_error")
case http.StatusNotFound:
return isRequestScopedNotFoundMessage(err.Error())
case http.StatusUnprocessableEntity:
return true
default:
@@ -2083,6 +2208,9 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
if auth == nil {
return
}
if isRequestScopedNotFoundResultError(resultErr) {
return
}
auth.Unavailable = true
auth.Status = StatusError
auth.UpdatedAt = now
@@ -2246,6 +2374,13 @@ func shouldRetrySchedulerPick(err error) bool {
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
}
func (m *Manager) routeAwareSelectionRequired(auth *Auth, routeModel string) bool {
if auth == nil || strings.TrimSpace(routeModel) == "" {
return false
}
return m.selectionModelKeyForAuth(auth, routeModel) != canonicalModelKey(routeModel)
}
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
@@ -2275,7 +2410,7 @@ func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, op
if _, used := tried[candidate.ID]; used {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
if modelKey != "" && !m.authSupportsRouteModel(registryRef, candidate, model) {
continue
}
candidates = append(candidates, candidate)
@@ -2284,7 +2419,12 @@ func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, op
m.mu.RUnlock()
return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
available, errAvailable := m.availableAuthsForRouteModel(candidates, provider, model, time.Now())
if errAvailable != nil {
m.mu.RUnlock()
return nil, nil, errAvailable
}
selected, errPick := m.selector.Pick(ctx, provider, selectionArgForSelector(m.selector, model), opts, available)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, errPick
@@ -2310,6 +2450,22 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
if !m.useSchedulerFastPath() {
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
if strings.TrimSpace(model) != "" {
m.mu.RLock()
for _, candidate := range m.auths {
if candidate == nil || candidate.Provider != provider || candidate.Disabled {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if m.routeAwareSelectionRequired(candidate, model) {
m.mu.RUnlock()
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
}
m.mu.RUnlock()
}
executor, okExecutor := m.Executor(provider)
if !okExecutor {
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
@@ -2383,7 +2539,7 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m
if _, ok := m.executors[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
if modelKey != "" && !m.authSupportsRouteModel(registryRef, candidate, model) {
continue
}
candidates = append(candidates, candidate)
@@ -2392,7 +2548,12 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
available, errAvailable := m.availableAuthsForRouteModel(candidates, "mixed", model, time.Now())
if errAvailable != nil {
m.mu.RUnlock()
return nil, nil, "", errAvailable
}
selected, errPick := m.selector.Pick(ctx, "mixed", selectionArgForSelector(m.selector, model), opts, available)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, "", errPick
@@ -2444,6 +2605,29 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
if len(eligibleProviders) == 0 {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
if strings.TrimSpace(model) != "" {
providerSet := make(map[string]struct{}, len(eligibleProviders))
for _, providerKey := range eligibleProviders {
providerSet[providerKey] = struct{}{}
}
m.mu.RLock()
for _, candidate := range m.auths {
if candidate == nil || candidate.Disabled {
continue
}
if _, ok := providerSet[strings.TrimSpace(strings.ToLower(candidate.Provider))]; !ok {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if m.routeAwareSelectionRequired(candidate, model) {
m.mu.RUnlock()
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
}
}
m.mu.RUnlock()
}
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {

View File

@@ -0,0 +1,111 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
"time"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type aliasRoutingExecutor struct {
id string
mu sync.Mutex
executeModels []string
}
func (e *aliasRoutingExecutor) Identifier() string { return e.id }
func (e *aliasRoutingExecutor) Execute(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
e.mu.Unlock()
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *aliasRoutingExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "ExecuteStream not implemented"}
}
func (e *aliasRoutingExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *aliasRoutingExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "CountTokens not implemented"}
}
func (e *aliasRoutingExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}
func (e *aliasRoutingExecutor) ExecuteModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.executeModels))
copy(out, e.executeModels)
return out
}
func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
const (
provider = "antigravity"
routeModel = "claude-opus-4-6"
targetModel = "claude-opus-4-6-thinking"
)
manager := NewManager(nil, nil, nil)
executor := &aliasRoutingExecutor{id: provider}
manager.RegisterExecutor(executor)
manager.SetOAuthModelAlias(map[string][]internalconfig.OAuthModelAlias{
provider: {{
Name: targetModel,
Alias: routeModel,
Fork: true,
}},
})
auth := &Auth{
ID: "oauth-alias-auth",
Provider: provider,
Status: StatusActive,
ModelStates: map[string]*ModelState{
routeModel: {
Unavailable: true,
Status: StatusError,
NextRetryAfter: time.Now().Add(1 * time.Hour),
},
},
}
if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: routeModel}, {ID: targetModel}})
t.Cleanup(func() {
reg.UnregisterClient(auth.ID)
})
manager.RefreshSchedulerEntry(auth.ID)
resp, errExecute := manager.Execute(context.Background(), []string{provider}, cliproxyexecutor.Request{Model: routeModel}, cliproxyexecutor.Options{})
if errExecute != nil {
t.Fatalf("execute error = %v, want success", errExecute)
}
if string(resp.Payload) != targetModel {
t.Fatalf("execute payload = %q, want %q", string(resp.Payload), targetModel)
}
gotModels := executor.ExecuteModels()
if len(gotModels) != 1 {
t.Fatalf("execute models len = %d, want 1", len(gotModels))
}
if gotModels[0] != targetModel {
t.Fatalf("execute model = %q, want %q", gotModels[0], targetModel)
}
}

View File

@@ -12,6 +12,8 @@ import (
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
const requestScopedNotFoundMessage = "Item with id 'rs_0b5f3eb6f51f175c0169ca74e4a85881998539920821603a74' not found. Items are not persisted when `store` is set to false. Try again with `store` set to true, or remove this item from your input."
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
m := NewManager(nil, nil, nil)
m.SetRetryConfig(3, 30*time.Second, 0)
@@ -447,3 +449,114 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
}
}
func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) {
m := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "openai",
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "gpt-4.1"
m.MarkResult(context.Background(), Result{
AuthID: auth.ID,
Provider: auth.Provider,
Model: model,
Success: false,
Error: &Error{
HTTPStatus: http.StatusNotFound,
Message: requestScopedNotFoundMessage,
},
})
updated, ok := m.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth to be present")
}
if updated.Unavailable {
t.Fatalf("expected request-scoped 404 to keep auth available")
}
if !updated.NextRetryAfter.IsZero() {
t.Fatalf("expected request-scoped 404 to keep auth cooldown unset, got %v", updated.NextRetryAfter)
}
if state := updated.ModelStates[model]; state != nil {
t.Fatalf("expected request-scoped 404 to avoid model cooldown state, got %#v", state)
}
}
func TestManager_RequestScopedNotFoundStopsRetryWithoutSuspendingAuth(t *testing.T) {
m := NewManager(nil, nil, nil)
executor := &authFallbackExecutor{
id: "openai",
executeErrors: map[string]error{
"aa-bad-auth": &Error{
HTTPStatus: http.StatusNotFound,
Message: requestScopedNotFoundMessage,
},
},
}
m.RegisterExecutor(executor)
model := "gpt-4.1"
badAuth := &Auth{ID: "aa-bad-auth", Provider: "openai"}
goodAuth := &Auth{ID: "bb-good-auth", Provider: "openai"}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(badAuth.ID, "openai", []*registry.ModelInfo{{ID: model}})
reg.RegisterClient(goodAuth.ID, "openai", []*registry.ModelInfo{{ID: model}})
t.Cleanup(func() {
reg.UnregisterClient(badAuth.ID)
reg.UnregisterClient(goodAuth.ID)
})
if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
t.Fatalf("register bad auth: %v", errRegister)
}
if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
t.Fatalf("register good auth: %v", errRegister)
}
_, errExecute := m.Execute(context.Background(), []string{"openai"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{})
if errExecute == nil {
t.Fatal("expected request-scoped not-found error")
}
errResult, ok := errExecute.(*Error)
if !ok {
t.Fatalf("expected *Error, got %T", errExecute)
}
if errResult.HTTPStatus != http.StatusNotFound {
t.Fatalf("status = %d, want %d", errResult.HTTPStatus, http.StatusNotFound)
}
if errResult.Message != requestScopedNotFoundMessage {
t.Fatalf("message = %q, want %q", errResult.Message, requestScopedNotFoundMessage)
}
got := executor.ExecuteCalls()
want := []string{badAuth.ID}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d auth = %q, want %q", i, got[i], want[i])
}
}
updatedBad, ok := m.GetByID(badAuth.ID)
if !ok || updatedBad == nil {
t.Fatalf("expected bad auth to remain registered")
}
if updatedBad.Unavailable {
t.Fatalf("expected request-scoped 404 to keep bad auth available")
}
if !updatedBad.NextRetryAfter.IsZero() {
t.Fatalf("expected request-scoped 404 to keep bad auth cooldown unset, got %v", updatedBad.NextRetryAfter)
}
if state := updatedBad.ModelStates[model]; state != nil {
t.Fatalf("expected request-scoped 404 to avoid bad auth model cooldown state, got %#v", state)
}
}

View File

@@ -219,6 +219,19 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
if len(normalized) == 0 {
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
if len(normalized) == 1 {
// When a single provider is eligible, reuse pickSingle so provider-specific preferences
// (for example Codex websocket transport) are applied consistently.
providerKey := normalized[0]
picked, errPick := s.pickSingle(ctx, providerKey, model, opts, tried)
if errPick != nil {
return nil, "", errPick
}
if picked == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
return picked, providerKey, nil
}
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
modelKey := canonicalModelKey(model)
@@ -293,12 +306,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
}
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
start := 0
if len(normalized) > 0 {
start = s.mixedCursors[cursorKey] % len(normalized)
weights := make([]int, len(normalized))
segmentStarts := make([]int, len(normalized))
segmentEnds := make([]int, len(normalized))
totalWeight := 0
for providerIndex, shard := range candidateShards {
segmentStarts[providerIndex] = totalWeight
if shard != nil {
weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority)
}
totalWeight += weights[providerIndex]
segmentEnds[providerIndex] = totalWeight
}
if totalWeight == 0 {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
startSlot := s.mixedCursors[cursorKey] % totalWeight
startProviderIndex := -1
for providerIndex := range normalized {
if weights[providerIndex] == 0 {
continue
}
if startSlot < segmentEnds[providerIndex] {
startProviderIndex = providerIndex
break
}
}
if startProviderIndex < 0 {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
slot := startSlot
for offset := 0; offset < len(normalized); offset++ {
providerIndex := (start + offset) % len(normalized)
providerIndex := (startProviderIndex + offset) % len(normalized)
if weights[providerIndex] == 0 {
continue
}
if providerIndex != startProviderIndex {
slot = segmentStarts[providerIndex]
}
providerKey := normalized[providerIndex]
shard := candidateShards[providerIndex]
if shard == nil {
@@ -308,7 +355,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
if picked == nil {
continue
}
s.mixedCursors[cursorKey] = providerIndex + 1
s.mixedCursors[cursorKey] = slot + 1
return picked, providerKey, nil
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
@@ -667,16 +714,25 @@ func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predic
if m == nil {
return 0, false
}
if preferWebsocket {
// When downstream is websocket and Codex supports websocket transport, prefer websocket-enabled
// credentials even if they are in a lower priority tier than HTTP-only credentials.
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
if bucket.ws.pickFirst(predicate) != nil {
return priority, true
}
}
}
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
if view.pickFirst(predicate) != nil {
if bucket.all.pickFirst(predicate) != nil {
return priority, true
}
}
@@ -694,7 +750,7 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
return nil
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
if preferWebsocket && bucket.ws.pickFirst(predicate) != nil {
view = &bucket.ws
}
var picked *scheduledAuth
@@ -709,6 +765,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
return picked.auth
}
func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int {
if m == nil {
return 0
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return 0
}
if preferWebsocket && len(bucket.ws.flat) > 0 {
return len(bucket.ws.flat)
}
return len(bucket.all.flat)
}
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
now := time.Now()

View File

@@ -208,7 +208,33 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledAcrossPriorities(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "codex-http", Provider: "codex", Attributes: map[string]string{"priority": "10"}},
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"priority": "0", "websockets": "true"}},
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"priority": "0", "websockets": "true"}},
)
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
@@ -218,8 +244,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
@@ -272,7 +298,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
}
}
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
@@ -288,8 +314,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
@@ -399,8 +425,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {

View File

@@ -347,6 +347,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
log.Errorf("failed to disable auth %s: %v", id, err)
}
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
executor.CloseCodexWebsocketSessionsForAuthID(existing.ID, "auth_removed")
s.ensureExecutorsForAuth(existing)
}
}