mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-03 11:12:46 +00:00
Compare commits
138 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
516d22c695 | ||
|
|
73cda6e836 | ||
|
|
0805989ee5 | ||
|
|
75da02af55 | ||
|
|
ab9ebea592 | ||
|
|
7ee37ee4b9 | ||
|
|
837afffb31 | ||
|
|
03a1bac898 | ||
|
|
3171d524f0 | ||
|
|
3e78a8d500 | ||
|
|
fcba912cc4 | ||
|
|
7170eeea5f | ||
|
|
e3eb048c7a | ||
|
|
a59e92435b | ||
|
|
108895fc04 | ||
|
|
abc293c642 | ||
|
|
bb44671845 | ||
|
|
09e480036a | ||
|
|
249f969110 | ||
|
|
4f8acec2d8 | ||
|
|
34339f61ee | ||
|
|
4045378cb4 | ||
|
|
2df35449fe | ||
|
|
c744179645 | ||
|
|
9720b03a6b | ||
|
|
f2c0f3d325 | ||
|
|
4f99bc54f1 | ||
|
|
913f4a9c5f | ||
|
|
25d1c18a3f | ||
|
|
d09dd4d0b2 | ||
|
|
474fb042da | ||
|
|
8435c3d7be | ||
|
|
e783d0a62e | ||
|
|
b05f575e9b | ||
|
|
f5e9f01811 | ||
|
|
ff7dbb5867 | ||
|
|
e34b2b4f1d | ||
|
|
15c2f274ea | ||
|
|
37249339ac | ||
|
|
c422d16beb | ||
|
|
66cd50f603 | ||
|
|
caa529c282 | ||
|
|
51a4379bf4 | ||
|
|
acf98ed10e | ||
|
|
d1c07a091e | ||
|
|
c1a8adf1ab | ||
|
|
08e078fc25 | ||
|
|
105a21548f | ||
|
|
1734aa1664 | ||
|
|
ca11b236a7 | ||
|
|
6fdff8227d | ||
|
|
330e12d3c2 | ||
|
|
bd09c0bf09 | ||
|
|
b468ca79c3 | ||
|
|
d2c7e4e96a | ||
|
|
1c7003ff68 | ||
|
|
1b44364e78 | ||
|
|
ec77f4a4f5 | ||
|
|
f611dd6e96 | ||
|
|
07b7c1a1e0 | ||
|
|
51fd58d74f | ||
|
|
faae9c2f7c | ||
|
|
bc3a6e4646 | ||
|
|
b09b03e35e | ||
|
|
16231947e7 | ||
|
|
39b9a38fbc | ||
|
|
bd855abec9 | ||
|
|
7c3c2e9f64 | ||
|
|
c10f8ae2e2 | ||
|
|
a0bf33eca6 | ||
|
|
88dd9c715d | ||
|
|
a3e21df814 | ||
|
|
d3b94c9241 | ||
|
|
c1d7599829 | ||
|
|
d11936f292 | ||
|
|
17363edf25 | ||
|
|
279cbbbb8a | ||
|
|
486cd4c343 | ||
|
|
25feceb783 | ||
|
|
d26752250d | ||
|
|
b15453c369 | ||
|
|
04ba8c8bc3 | ||
|
|
6570692291 | ||
|
|
f73d55ddaa | ||
|
|
13aa5b3375 | ||
|
|
0fcc02fbea | ||
|
|
c03883ccf0 | ||
|
|
134a9eac9d | ||
|
|
6d8de0ade4 | ||
|
|
1587ff5e74 | ||
|
|
f033d3a6df | ||
|
|
145e0e0b5d | ||
|
|
f8d1bc06ea | ||
|
|
d5930f4e44 | ||
|
|
9b7d7021af | ||
|
|
e41c22ef44 | ||
|
|
55271403fb | ||
|
|
36fba66619 | ||
|
|
b9b127a7ea | ||
|
|
2741e7b7b3 | ||
|
|
1767a56d4f | ||
|
|
779e6c2d2f | ||
|
|
73c831747b | ||
|
|
b8b89f34f4 | ||
|
|
1fa094dac6 | ||
|
|
e5d3541b5a | ||
|
|
79755e76ea | ||
|
|
35f158d526 | ||
|
|
6962e09dd9 | ||
|
|
4c4cbd44da | ||
|
|
26eca8b6ba | ||
|
|
62b17f40a1 | ||
|
|
511b8a992e | ||
|
|
0ab977c236 | ||
|
|
224f0de353 | ||
|
|
d54de441d3 | ||
|
|
7386a70724 | ||
|
|
1b7447b682 | ||
|
|
40dee4453a | ||
|
|
8902e1cccb | ||
|
|
de5fe71478 | ||
|
|
dcfbec2990 | ||
|
|
c95620f90e | ||
|
|
754f3bcbc3 | ||
|
|
36973d4a6f | ||
|
|
9613f0b3f9 | ||
|
|
274f29e26b | ||
|
|
c8e79c3787 | ||
|
|
8afef43887 | ||
|
|
c1083cbfc6 | ||
|
|
c89d19b300 | ||
|
|
19c52bcb60 | ||
|
|
cc32f5ff61 | ||
|
|
fbff68b9e0 | ||
|
|
7e1a543b79 | ||
|
|
74b862d8b8 | ||
|
|
5c817a9b42 | ||
|
|
5da0decef6 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
# Binaries
|
||||
cli-proxy-api
|
||||
cliproxy
|
||||
/server
|
||||
*.exe
|
||||
|
||||
|
||||
@@ -36,13 +37,13 @@ GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.worktrees/
|
||||
.codex/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# CLIProxyAPI Plus
|
||||
|
||||
[English](README.md) | 中文 | [日本語](README_JA.md)
|
||||
[English](README.md) | 中文
|
||||
|
||||
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
||||
|
||||
|
||||
187
README_JA.md
187
README_JA.md
@@ -1,187 +0,0 @@
|
||||
# CLI Proxy API
|
||||
|
||||
[English](README.md) | [中文](README_CN.md) | 日本語
|
||||
|
||||
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
|
||||
|
||||
OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポートしています。
|
||||
|
||||
ローカルまたはマルチアカウントのCLIアクセスを、OpenAI(Responses含む)/Gemini/Claude互換のクライアントやSDKで利用できます。
|
||||
|
||||
## スポンサー
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
|
||||
|
||||
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
|
||||
|
||||
GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
|
||||
---
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
|
||||
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## 概要
|
||||
|
||||
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
|
||||
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
|
||||
- OAuthログインによるClaude Codeサポート
|
||||
- OAuthログインによるQwen Codeサポート
|
||||
- OAuthログインによるiFlowサポート
|
||||
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
|
||||
- ストリーミングおよび非ストリーミングレスポンス
|
||||
- 関数呼び出し/ツールのサポート
|
||||
- マルチモーダル入力サポート(テキストと画像)
|
||||
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||
- Generative Language APIキーのサポート
|
||||
- AI Studioビルドのマルチアカウント負荷分散
|
||||
- Gemini CLIのマルチアカウント負荷分散
|
||||
- Claude Codeのマルチアカウント負荷分散
|
||||
- Qwen Codeのマルチアカウント負荷分散
|
||||
- iFlowのマルチアカウント負荷分散
|
||||
- OpenAI Codexのマルチアカウント負荷分散
|
||||
- 設定によるOpenAI互換アップストリームプロバイダー(例:OpenRouter)
|
||||
- プロキシ埋め込み用の再利用可能なGo SDK(`docs/sdk-usage.md`を参照)
|
||||
|
||||
## はじめに
|
||||
|
||||
CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/)
|
||||
|
||||
## 管理API
|
||||
|
||||
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
|
||||
|
||||
## Amp CLIサポート
|
||||
|
||||
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
|
||||
|
||||
- Ampの APIパターン用のプロバイダールートエイリアス(`/api/provider/{provider}/v1...`)
|
||||
- OAuth認証およびアカウント機能用の管理プロキシ
|
||||
- 自動ルーティングによるスマートモデルフォールバック
|
||||
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
|
||||
|
||||
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||
|
||||
## SDKドキュメント
|
||||
|
||||
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
|
||||
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
|
||||
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
|
||||
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
|
||||
- カスタムプロバイダーの例:`examples/custom-provider`
|
||||
|
||||
## コントリビューション
|
||||
|
||||
コントリビューションを歓迎します!お気軽にPull Requestを送ってください。
|
||||
|
||||
1. リポジトリをフォーク
|
||||
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`)
|
||||
3. 変更をコミット(`git commit -m 'Add some amazing feature'`)
|
||||
4. ブランチにプッシュ(`git push origin feature/amazing-feature`)
|
||||
5. Pull Requestを作成
|
||||
|
||||
## 関連プロジェクト
|
||||
|
||||
CLIProxyAPIをベースにした以下のプロジェクトがあります:
|
||||
|
||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
||||
|
||||
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
|
||||
|
||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
||||
|
||||
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
|
||||
|
||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
||||
|
||||
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデル(Gemini、Codex、Antigravity)を即座に切り替えるCLIラッパー - APIキー不要
|
||||
|
||||
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
||||
|
||||
CLIProxyAPI管理用のmacOSネイティブGUI:OAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
|
||||
|
||||
### [CodMate](https://github.com/loocor/CodMate)
|
||||
|
||||
CLI AIセッション(Codex、Claude Code、Gemini CLI)を管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
|
||||
|
||||
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||
|
||||
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
|
||||
|
||||
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||
|
||||
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
|
||||
|
||||
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
||||
|
||||
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
|
||||
|
||||
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||
|
||||
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替え(Main / Plus)、自動ダウンロードおよび自動更新に対応
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
|
||||
|
||||
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
||||
|
||||
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
|
||||
|
||||
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||
|
||||
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
|
||||
|
||||
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
|
||||
|
||||
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LAN(ローカルエリアネットワーク)を介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
## その他の選択肢
|
||||
|
||||
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです:
|
||||
|
||||
### [9Router](https://github.com/decolua/9router)
|
||||
|
||||
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換(OpenAI/Claude/Gemini/Ollama)、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツール(Cursor、Claude Code、Cline、RooCode)のサポートをゼロから構築 - APIキー不要
|
||||
|
||||
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
||||
|
||||
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
|
||||
|
||||
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
## ライセンス
|
||||
|
||||
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。
|
||||
BIN
assets/lingtrue.png
Normal file
BIN
assets/lingtrue.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 129 KiB |
20
cmd/mcpdebug/main.go
Normal file
20
cmd/mcpdebug/main.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Encode MCP result with empty execId
|
||||
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
|
||||
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
|
||||
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
|
||||
|
||||
// Write to file for analysis
|
||||
os.WriteFile("mcp_result.bin", resultBytes)
|
||||
fmt.Println("Wrote mcp_result.bin")
|
||||
}
|
||||
32
cmd/protocheck/main.go
Normal file
32
cmd/protocheck/main.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||
|
||||
// Try different field names
|
||||
names := []string{
|
||||
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||
"shell_result", "shellResult",
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
fd := ecm.Descriptor().Fields().ByName(name)
|
||||
if fd != nil {
|
||||
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
|
||||
} else {
|
||||
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
// List all fields
|
||||
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||
f := ecm.Descriptor().Fields().Get(i)
|
||||
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,7 @@ func main() {
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var cursorLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
@@ -123,6 +124,7 @@ func main() {
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
@@ -544,6 +546,8 @@ func main() {
|
||||
cmd.DoGitLabTokenLogin(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if cursorLogin {
|
||||
cmd.DoCursorLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
|
||||
@@ -96,6 +96,7 @@ max-retry-interval: 30
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
@@ -177,6 +178,8 @@ nonstream-keepalive-interval: 0
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
|
||||
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||
@@ -313,6 +316,10 @@ nonstream-keepalive-interval: 0
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||
# you select the protocol surface, but inference backend selection can still follow the resolved
|
||||
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
|
||||
1
go.mod
1
go.mod
@@ -83,6 +83,7 @@ require (
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pierrec/xxHash v0.1.5
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
|
||||
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
@@ -1046,6 +1047,7 @@ func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Aut
|
||||
auth.Runtime = existing.Runtime
|
||||
}
|
||||
}
|
||||
coreauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
@@ -1128,7 +1130,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
@@ -1136,11 +1138,12 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
@@ -1176,13 +1179,107 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
|
||||
changed := false
|
||||
if req.Prefix != nil {
|
||||
targetAuth.Prefix = *req.Prefix
|
||||
prefix := strings.TrimSpace(*req.Prefix)
|
||||
targetAuth.Prefix = prefix
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if prefix == "" {
|
||||
delete(targetAuth.Metadata, "prefix")
|
||||
} else {
|
||||
targetAuth.Metadata["prefix"] = prefix
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if req.ProxyURL != nil {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
proxyURL := strings.TrimSpace(*req.ProxyURL)
|
||||
targetAuth.ProxyURL = proxyURL
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if proxyURL == "" {
|
||||
delete(targetAuth.Metadata, "proxy_url")
|
||||
} else {
|
||||
targetAuth.Metadata["proxy_url"] = proxyURL
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if len(req.Headers) > 0 {
|
||||
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata)
|
||||
nextHeaders := make(map[string]string, len(existingHeaders))
|
||||
for k, v := range existingHeaders {
|
||||
nextHeaders[k] = v
|
||||
}
|
||||
headerChanged := false
|
||||
|
||||
for key, value := range req.Headers {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(value)
|
||||
attrKey := "header:" + name
|
||||
if val == "" {
|
||||
if _, ok := nextHeaders[name]; ok {
|
||||
delete(nextHeaders, name)
|
||||
headerChanged = true
|
||||
}
|
||||
if targetAuth.Attributes != nil {
|
||||
if _, ok := targetAuth.Attributes[attrKey]; ok {
|
||||
headerChanged = true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if prev, ok := nextHeaders[name]; !ok || prev != val {
|
||||
headerChanged = true
|
||||
}
|
||||
nextHeaders[name] = val
|
||||
if targetAuth.Attributes != nil {
|
||||
if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val {
|
||||
headerChanged = true
|
||||
}
|
||||
} else {
|
||||
headerChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if headerChanged {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if targetAuth.Attributes == nil {
|
||||
targetAuth.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
for key, value := range req.Headers {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(value)
|
||||
attrKey := "header:" + name
|
||||
if val == "" {
|
||||
delete(nextHeaders, name)
|
||||
delete(targetAuth.Attributes, attrKey)
|
||||
continue
|
||||
}
|
||||
nextHeaders[name] = val
|
||||
targetAuth.Attributes[attrKey] = val
|
||||
}
|
||||
|
||||
if len(nextHeaders) == 0 {
|
||||
delete(targetAuth.Metadata, "headers")
|
||||
} else {
|
||||
metaHeaders := make(map[string]any, len(nextHeaders))
|
||||
for k, v := range nextHeaders {
|
||||
metaHeaders[k] = v
|
||||
}
|
||||
targetAuth.Metadata["headers"] = metaHeaders
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if req.Priority != nil || req.Note != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
@@ -2137,9 +2234,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata["oauth_client_id"] = clientID
|
||||
if clientSecret != "" {
|
||||
metadata["oauth_client_secret"] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
@@ -3707,3 +3801,84 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
"verification_uri": resp.VerificationURL,
|
||||
})
|
||||
}
|
||||
|
||||
// RequestCursorToken initiates the Cursor PKCE authentication flow.
|
||||
// Supports multiple accounts via ?label=xxx query parameter.
|
||||
// The user opens the returned URL in a browser, logs in, and the server polls
|
||||
// until the authentication completes.
|
||||
func (h *Handler) RequestCursorToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
label := strings.TrimSpace(c.Query("label"))
|
||||
log.Infof("Initializing Cursor authentication (label=%q)...", label)
|
||||
|
||||
authParams, err := cursorauth.GenerateAuthParams()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate Cursor auth params: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate auth params"})
|
||||
return
|
||||
}
|
||||
|
||||
state := fmt.Sprintf("cur-%d", time.Now().UnixNano())
|
||||
RegisterOAuthSession(state, "cursor")
|
||||
|
||||
go func() {
|
||||
log.Info("Waiting for Cursor authentication...")
|
||||
log.Infof("Open this URL in your browser: %s", authParams.LoginURL)
|
||||
|
||||
tokens, errPoll := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
|
||||
if errPoll != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed: "+errPoll.Error())
|
||||
log.Errorf("Cursor authentication failed: %v", errPoll)
|
||||
return
|
||||
}
|
||||
|
||||
// Build metadata
|
||||
metadata := map[string]any{
|
||||
"type": "cursor",
|
||||
"access_token": tokens.AccessToken,
|
||||
"refresh_token": tokens.RefreshToken,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// Extract expiry and account identity from JWT
|
||||
expiry := cursorauth.GetTokenExpiry(tokens.AccessToken)
|
||||
if !expiry.IsZero() {
|
||||
metadata["expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Auto-identify account from JWT sub claim for multi-account support
|
||||
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
|
||||
subHash := cursorauth.SubToShortHash(sub)
|
||||
if sub != "" {
|
||||
metadata["sub"] = sub
|
||||
}
|
||||
|
||||
fileName := cursorauth.CredentialFileName(label, subHash)
|
||||
displayLabel := cursorauth.DisplayLabel(label, subHash)
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "cursor",
|
||||
FileName: fileName,
|
||||
Label: displayLabel,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save Cursor tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save tokens")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Cursor authentication successful! Token saved to %s", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("cursor")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"url": authParams.LoginURL,
|
||||
"state": state,
|
||||
})
|
||||
}
|
||||
|
||||
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "test.json",
|
||||
FileName: "test.json",
|
||||
Provider: "claude",
|
||||
Attributes: map[string]string{
|
||||
"path": "/tmp/test.json",
|
||||
"header:X-Old": "old",
|
||||
"header:X-Remove": "gone",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "claude",
|
||||
"headers": map[string]any{
|
||||
"X-Old": "old",
|
||||
"X-Remove": "gone",
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||
|
||||
body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx.Request = req
|
||||
h.PatchAuthFileFields(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID("test.json")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth record to exist after patch")
|
||||
}
|
||||
|
||||
if updated.Prefix != "p1" {
|
||||
t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
|
||||
}
|
||||
if updated.ProxyURL != "http://proxy.local" {
|
||||
t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
|
||||
}
|
||||
|
||||
if updated.Metadata == nil {
|
||||
t.Fatalf("expected metadata to be non-nil")
|
||||
}
|
||||
if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
|
||||
t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
|
||||
}
|
||||
if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
|
||||
t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
|
||||
}
|
||||
|
||||
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||
if !ok {
|
||||
raw, _ := json.Marshal(updated.Metadata["headers"])
|
||||
t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
|
||||
}
|
||||
if got := headersMeta["X-Old"]; got != "new" {
|
||||
t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
|
||||
}
|
||||
if got := headersMeta["X-New"]; got != "v" {
|
||||
t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
|
||||
}
|
||||
if _, ok := headersMeta["X-Remove"]; ok {
|
||||
t.Fatalf("expected metadata.headers.X-Remove to be deleted")
|
||||
}
|
||||
if _, ok := headersMeta["X-Nope"]; ok {
|
||||
t.Fatalf("expected metadata.headers.X-Nope to be absent")
|
||||
}
|
||||
|
||||
if got := updated.Attributes["header:X-Old"]; got != "new" {
|
||||
t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
|
||||
}
|
||||
if got := updated.Attributes["header:X-New"]; got != "v" {
|
||||
t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
|
||||
}
|
||||
if _, ok := updated.Attributes["header:X-Remove"]; ok {
|
||||
t.Fatalf("expected attrs header:X-Remove to be deleted")
|
||||
}
|
||||
if _, ok := updated.Attributes["header:X-Nope"]; ok {
|
||||
t.Fatalf("expected attrs header:X-Nope to be absent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "noop.json",
|
||||
FileName: "noop.json",
|
||||
Provider: "claude",
|
||||
Attributes: map[string]string{
|
||||
"path": "/tmp/noop.json",
|
||||
"header:X-Kee": "1",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "claude",
|
||||
"headers": map[string]any{
|
||||
"X-Kee": "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||
|
||||
body := `{"name":"noop.json","note":"hello","headers":{}}`
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx.Request = req
|
||||
h.PatchAuthFileFields(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID("noop.json")
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth record to exist after patch")
|
||||
}
|
||||
if got := updated.Attributes["header:X-Kee"]; got != "1" {
|
||||
t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
|
||||
}
|
||||
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
|
||||
}
|
||||
if got := headersMeta["X-Kee"]; got != "1" {
|
||||
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
|
||||
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
|
||||
if len(apiWebsocketTimeline) > 0 {
|
||||
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
|
||||
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
data, ok := apiTimeline.([]byte)
|
||||
if !ok || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.Clone(data)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||
if !isExist {
|
||||
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
|
||||
return body
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
|
||||
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
|
||||
return body
|
||||
}
|
||||
if w.body == nil || w.body.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.Clone(w.body.Bytes())
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
|
||||
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
|
||||
}
|
||||
|
||||
func extractBodyOverride(c *gin.Context, key string) []byte {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
bodyOverride, isExist := c.Get(key)
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
websocketTimeline,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiWebsocketTimeline,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
websocketTimeline,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiWebsocketTimeline,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
w.requestInfo.Timestamp,
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||
wrapper.body.WriteString("original-response")
|
||||
|
||||
body := wrapper.extractResponseBody(c)
|
||||
if string(body) != "original-response" {
|
||||
t.Fatalf("response body = %q, want %q", string(body), "original-response")
|
||||
}
|
||||
|
||||
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
|
||||
body = wrapper.extractResponseBody(c)
|
||||
if string(body) != "override-response" {
|
||||
t.Fatalf("response body = %q, want %q", string(body), "override-response")
|
||||
}
|
||||
|
||||
body[0] = 'X'
|
||||
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
|
||||
t.Fatalf("response override should be cloned, got %q", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
|
||||
|
||||
body := wrapper.extractResponseBody(c)
|
||||
if string(body) != "override-response-as-string" {
|
||||
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
override := []byte("body-override")
|
||||
c.Set(requestBodyOverrideContextKey, override)
|
||||
|
||||
body := extractBodyOverride(c, requestBodyOverrideContextKey)
|
||||
if !bytes.Equal(body, override) {
|
||||
t.Fatalf("body override = %q, want %q", string(body), string(override))
|
||||
}
|
||||
|
||||
body[0] = 'X'
|
||||
if !bytes.Equal(override, []byte("body-override")) {
|
||||
t.Fatalf("override mutated: %q", string(override))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
if got := wrapper.extractWebsocketTimeline(c); got != nil {
|
||||
t.Fatalf("expected nil websocket timeline, got %q", string(got))
|
||||
}
|
||||
|
||||
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
|
||||
body := wrapper.extractWebsocketTimeline(c)
|
||||
if string(body) != "timeline" {
|
||||
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
streamWriter := &testStreamingLogWriter{}
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
ResponseWriter: c.Writer,
|
||||
logger: &testRequestLogger{enabled: true},
|
||||
requestInfo: &RequestInfo{
|
||||
URL: "/v1/responses",
|
||||
Method: "POST",
|
||||
Headers: map[string][]string{"Content-Type": {"application/json"}},
|
||||
RequestID: "req-1",
|
||||
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
|
||||
},
|
||||
isStreaming: true,
|
||||
streamWriter: streamWriter,
|
||||
}
|
||||
|
||||
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
|
||||
|
||||
if err := wrapper.Finalize(c); err != nil {
|
||||
t.Fatalf("Finalize error: %v", err)
|
||||
}
|
||||
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
|
||||
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
|
||||
}
|
||||
if !streamWriter.closed {
|
||||
t.Fatal("expected stream writer to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
type testRequestLogger struct {
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
|
||||
return &testStreamingLogWriter{}, nil
|
||||
}
|
||||
|
||||
func (l *testRequestLogger) IsEnabled() bool {
|
||||
return l.enabled
|
||||
}
|
||||
|
||||
type testStreamingLogWriter struct {
|
||||
apiWebsocketTimeline []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
|
||||
|
||||
func (w *testStreamingLogWriter) Close() error {
|
||||
w.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
return
|
||||
}
|
||||
|
||||
// Sanitize request body: remove thinking blocks with invalid signatures
|
||||
// to prevent upstream API 400 errors
|
||||
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
|
||||
|
||||
// Restore the body for the handler to read
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
@@ -249,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
rewriter.suppressThinking = true
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
@@ -259,10 +264,17 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
// Wrap with ResponseRewriter for local providers too, because upstream
|
||||
// proxies (e.g. NewAPI) may return a different model name and lack
|
||||
// Amp-required fields like thinking.signature.
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
rewriter.suppressThinking = providerName != "claude"
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
rewriter.Flush()
|
||||
} else {
|
||||
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
@@ -2,6 +2,7 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -12,15 +13,17 @@ import (
|
||||
)
|
||||
|
||||
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||
// It's used to rewrite model names in responses when model mapping is used
|
||||
// It is used to rewrite model names in responses when model mapping is used
|
||||
// and to keep Amp-compatible response shapes.
|
||||
type ResponseRewriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
originalModel string
|
||||
isStreaming bool
|
||||
body *bytes.Buffer
|
||||
originalModel string
|
||||
isStreaming bool
|
||||
suppressThinking bool
|
||||
}
|
||||
|
||||
// NewResponseRewriter creates a new response rewriter for model name substitution
|
||||
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
||||
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||
return &ResponseRewriter{
|
||||
ResponseWriter: w,
|
||||
@@ -33,15 +36,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
||||
|
||||
func looksLikeSSEChunk(data []byte) bool {
|
||||
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||
// Heuristics are intentionally simple and cheap.
|
||||
return bytes.Contains(data, []byte("data:")) ||
|
||||
bytes.Contains(data, []byte("event:")) ||
|
||||
bytes.Contains(data, []byte("message_start")) ||
|
||||
bytes.Contains(data, []byte("message_delta")) ||
|
||||
bytes.Contains(data, []byte("content_block_start")) ||
|
||||
bytes.Contains(data, []byte("content_block_delta")) ||
|
||||
bytes.Contains(data, []byte("content_block_stop")) ||
|
||||
bytes.Contains(data, []byte("\n\n"))
|
||||
// We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
|
||||
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if bytes.HasPrefix(trimmed, []byte("data:")) ||
|
||||
bytes.HasPrefix(trimmed, []byte("event:")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||
@@ -95,7 +98,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
rewritten := rw.rewriteStreamChunk(data)
|
||||
n, err := rw.ResponseWriter.Write(rewritten)
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
@@ -106,7 +110,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
// Flush writes the buffered response with model names rewritten
|
||||
func (rw *ResponseRewriter) Flush() {
|
||||
if rw.isStreaming {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
@@ -115,40 +118,79 @@ func (rw *ResponseRewriter) Flush() {
|
||||
return
|
||||
}
|
||||
if rw.body.Len() > 0 {
|
||||
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
||||
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||
// Update Content-Length to match the rewritten body size, since
|
||||
// signature injection and model name changes alter the payload length.
|
||||
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
|
||||
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
|
||||
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
||||
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||
func ensureAmpSignature(data []byte) []byte {
|
||||
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType != "tool_use" && blockType != "thinking" {
|
||||
continue
|
||||
}
|
||||
signaturePath := fmt.Sprintf("content.%d.signature", index)
|
||||
if gjson.GetBytes(data, signaturePath).Exists() {
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, signaturePath, "")
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
|
||||
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content_block.signature", "")
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||
if !rw.suppressThinking {
|
||||
return data
|
||||
}
|
||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||
if filtered.Exists() {
|
||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||
filteredCount := filtered.Get("#").Int()
|
||||
|
||||
if originalCount > filteredCount {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||
} else {
|
||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
||||
// Log the result for verification
|
||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
data = ensureAmpSignature(data)
|
||||
data = rw.suppressAmpThinking(data)
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
if rw.originalModel == "" {
|
||||
return data
|
||||
}
|
||||
@@ -160,24 +202,154 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
||||
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||
if rw.originalModel == "" {
|
||||
return chunk
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
var out [][]byte
|
||||
|
||||
i := 0
|
||||
for i < len(lines) {
|
||||
line := lines[i]
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
|
||||
// Case 1: "event:" line - look ahead for its "data:" line
|
||||
if bytes.HasPrefix(trimmed, []byte("event: ")) {
|
||||
// Scan forward past blank lines to find the data: line
|
||||
dataIdx := -1
|
||||
for j := i + 1; j < len(lines); j++ {
|
||||
t := bytes.TrimSpace(lines[j])
|
||||
if len(t) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(t, []byte("data: ")) {
|
||||
dataIdx = j
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if dataIdx >= 0 {
|
||||
// Found event+data pair - process through rewriter
|
||||
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||
if rewritten == nil {
|
||||
i = dataIdx + 1
|
||||
continue
|
||||
}
|
||||
// Emit event line
|
||||
out = append(out, line)
|
||||
// Emit blank lines between event and data
|
||||
for k := i + 1; k < dataIdx; k++ {
|
||||
out = append(out, lines[k])
|
||||
}
|
||||
// Emit rewritten data
|
||||
out = append(out, append([]byte("data: "), rewritten...))
|
||||
i = dataIdx + 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// No data line found (orphan event from cross-chunk split)
|
||||
// Pass it through as-is - the data will arrive in the next chunk
|
||||
out = append(out, line)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Case 2: standalone "data:" line (no preceding event: in this chunk)
|
||||
if bytes.HasPrefix(trimmed, []byte("data: ")) {
|
||||
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||
if rewritten != nil {
|
||||
out = append(out, append([]byte("data: "), rewritten...))
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Case 3: everything else
|
||||
out = append(out, line)
|
||||
i++
|
||||
}
|
||||
|
||||
// SSE format: "data: {json}\n\n"
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i, line := range lines {
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
// Rewrite JSON in the data line
|
||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||
lines[i] = append([]byte("data: "), rewritten...)
|
||||
return bytes.Join(out, []byte("\n"))
|
||||
}
|
||||
|
||||
// rewriteStreamEvent processes a single JSON event in the SSE stream.
|
||||
// It rewrites model names and ensures signature fields exist.
|
||||
// NOTE: streaming mode does NOT suppress thinking blocks - they are
|
||||
// passed through with signature injection to avoid breaking SSE index
|
||||
// alignment and TUI rendering.
|
||||
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||
// Inject empty signature where needed
|
||||
data = ensureAmpSignature(data)
|
||||
|
||||
// Rewrite model name
|
||||
if rw.originalModel != "" {
|
||||
for _, path := range modelFieldPaths {
|
||||
if gjson.GetBytes(data, path).Exists() {
|
||||
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
return data
|
||||
}
|
||||
|
||||
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
||||
// from the messages array in a request body before forwarding to the upstream API.
|
||||
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
|
||||
func SanitizeAmpRequestBody(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
for msgIdx, msg := range messages.Array() {
|
||||
if msg.Get("role").String() != "assistant" {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
var keepBlocks []interface{}
|
||||
removedCount := 0
|
||||
|
||||
for _, block := range content.Array() {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "thinking" {
|
||||
sig := block.Get("signature")
|
||||
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
||||
removedCount++
|
||||
continue
|
||||
}
|
||||
}
|
||||
keepBlocks = append(keepBlocks, block.Value())
|
||||
}
|
||||
|
||||
if removedCount > 0 {
|
||||
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
||||
var err error
|
||||
if len(keepBlocks) == 0 {
|
||||
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
|
||||
} else {
|
||||
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
|
||||
continue
|
||||
}
|
||||
modified = true
|
||||
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
|
||||
}
|
||||
}
|
||||
|
||||
if modified {
|
||||
log.Debugf("Amp RequestSanitizer: sanitized request body")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -100,6 +101,50 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
|
||||
rw := &ResponseRewriter{}
|
||||
|
||||
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
// Streaming mode preserves thinking blocks (does NOT suppress them)
|
||||
// to avoid breaking SSE index alignment and TUI rendering
|
||||
if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
|
||||
t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
|
||||
t.Fatalf("expected thinking_delta to be preserved, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
|
||||
t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
|
||||
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
|
||||
}
|
||||
// Signature should be injected into both thinking and tool_use blocks
|
||||
if count := strings.Count(string(result), `"signature":""`); count != 2 {
|
||||
t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
|
||||
result := SanitizeAmpRequestBody(input)
|
||||
|
||||
if contains(result, []byte("drop-whitespace")) {
|
||||
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
|
||||
}
|
||||
if contains(result, []byte("drop-number")) {
|
||||
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte("keep-valid")) {
|
||||
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte("keep-text")) {
|
||||
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
|
||||
@@ -323,6 +323,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// setupRoutes configures the API routes for the server.
|
||||
// It defines the endpoints and associates them with their respective handlers.
|
||||
func (s *Server) setupRoutes() {
|
||||
s.engine.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
||||
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
||||
@@ -682,6 +686,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
|
||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server {
|
||||
return NewServer(cfg, authManager, accessManager, configPath)
|
||||
}
|
||||
|
||||
func TestHealthz(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String())
|
||||
}
|
||||
if resp.Status != "ok" {
|
||||
t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -172,6 +195,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
"issue-1711",
|
||||
time.Now(),
|
||||
|
||||
@@ -88,7 +88,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
||||
"client_id": {ClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"scope": {"org:create_api_key user:profile user:inference"},
|
||||
"scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
|
||||
33
internal/auth/cursor/filename.go
Normal file
33
internal/auth/cursor/filename.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Cursor credentials.
|
||||
// Priority: explicit label > auto-generated from JWT sub hash.
|
||||
// If both label and subHash are empty, falls back to "cursor.json".
|
||||
func CredentialFileName(label, subHash string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
subHash = strings.TrimSpace(subHash)
|
||||
if label != "" {
|
||||
return fmt.Sprintf("cursor.%s.json", label)
|
||||
}
|
||||
if subHash != "" {
|
||||
return fmt.Sprintf("cursor.%s.json", subHash)
|
||||
}
|
||||
return "cursor.json"
|
||||
}
|
||||
|
||||
// DisplayLabel returns a human-readable label for the Cursor account.
|
||||
func DisplayLabel(label, subHash string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
if label != "" {
|
||||
return "Cursor " + label
|
||||
}
|
||||
if subHash != "" {
|
||||
return "Cursor " + subHash
|
||||
}
|
||||
return "Cursor User"
|
||||
}
|
||||
249
internal/auth/cursor/oauth.go
Normal file
249
internal/auth/cursor/oauth.go
Normal file
@@ -0,0 +1,249 @@
|
||||
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CursorLoginURL = "https://cursor.com/loginDeepControl"
|
||||
CursorPollURL = "https://api2.cursor.sh/auth/poll"
|
||||
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
|
||||
|
||||
pollMaxAttempts = 150
|
||||
pollBaseDelay = 1 * time.Second
|
||||
pollMaxDelay = 10 * time.Second
|
||||
pollBackoffMultiply = 1.2
|
||||
maxConsecutiveErrors = 10
|
||||
)
|
||||
|
||||
// AuthParams holds the PKCE parameters for Cursor login.
|
||||
type AuthParams struct {
|
||||
Verifier string
|
||||
Challenge string
|
||||
UUID string
|
||||
LoginURL string
|
||||
}
|
||||
|
||||
// TokenPair holds the access and refresh tokens from Cursor.
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// GeneratePKCE creates a PKCE verifier and challenge pair.
|
||||
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
verifierBytes := make([]byte, 96)
|
||||
if _, err = rand.Read(verifierBytes); err != nil {
|
||||
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// GenerateAuthParams creates the full set of auth params for Cursor login.
|
||||
func GenerateAuthParams() (*AuthParams, error) {
|
||||
verifier, challenge, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uuidBytes := make([]byte, 16)
|
||||
if _, err = rand.Read(uuidBytes); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
|
||||
}
|
||||
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
|
||||
|
||||
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
|
||||
CursorLoginURL, challenge, uuid)
|
||||
|
||||
return &AuthParams{
|
||||
Verifier: verifier,
|
||||
Challenge: challenge,
|
||||
UUID: uuid,
|
||||
LoginURL: loginURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PollForAuth polls the Cursor auth endpoint until the user completes login.
|
||||
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
|
||||
delay := pollBaseDelay
|
||||
consecutiveErrors := 0
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= maxConsecutiveErrors {
|
||||
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
|
||||
}
|
||||
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// Still waiting for user to authorize
|
||||
consecutiveErrors = 0
|
||||
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
var tokens TokenPair
|
||||
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
|
||||
}
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
|
||||
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
|
||||
}
|
||||
|
||||
// RefreshToken refreshes a Cursor access token using the refresh token.
|
||||
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
|
||||
strings.NewReader("{}"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+refreshToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokens TokenPair
|
||||
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Keep original refresh token if not returned
|
||||
if tokens.RefreshToken == "" {
|
||||
tokens.RefreshToken = refreshToken
|
||||
}
|
||||
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
|
||||
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
|
||||
// the account. Returns empty string if parsing fails.
|
||||
func ParseJWTSub(token string) string {
|
||||
decoded := decodeJWTPayload(token)
|
||||
if decoded == nil {
|
||||
return ""
|
||||
}
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
|
||||
// e.g. "auth0|user_2x..." → "a3f8b2c1"
|
||||
func SubToShortHash(sub string) string {
|
||||
if sub == "" {
|
||||
return ""
|
||||
}
|
||||
h := sha256.Sum256([]byte(sub))
|
||||
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
|
||||
}
|
||||
|
||||
// decodeJWTPayload decodes the payload (middle) part of a JWT.
|
||||
func decodeJWTPayload(token string) []byte {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil
|
||||
}
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
payload = strings.ReplaceAll(payload, "-", "+")
|
||||
payload = strings.ReplaceAll(payload, "_", "/")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
|
||||
// Falls back to 1 hour from now if the token can't be parsed.
|
||||
func GetTokenExpiry(token string) time.Time {
|
||||
decoded := decodeJWTPayload(token)
|
||||
if decoded == nil {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Exp float64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
sec, frac := math.Modf(claims.Exp)
|
||||
expiry := time.Unix(int64(sec), int64(frac*1e9))
|
||||
// Subtract 5-minute safety margin
|
||||
return expiry.Add(-5 * time.Minute)
|
||||
}
|
||||
|
||||
func minDuration(a, b time.Duration) time.Duration {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
84
internal/auth/cursor/proto/connect.go
Normal file
84
internal/auth/cursor/proto/connect.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
|
||||
ConnectEndStreamFlag byte = 0x02
|
||||
// ConnectCompressionFlag indicates the payload is compressed (not supported).
|
||||
ConnectCompressionFlag byte = 0x01
|
||||
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
|
||||
ConnectFrameHeaderSize = 5
|
||||
)
|
||||
|
||||
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
|
||||
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
|
||||
func FrameConnectMessage(data []byte, flags byte) []byte {
|
||||
frame := make([]byte, ConnectFrameHeaderSize+len(data))
|
||||
frame[0] = flags
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
|
||||
copy(frame[5:], data)
|
||||
return frame
|
||||
}
|
||||
|
||||
// ParseConnectFrame extracts one frame from a buffer.
|
||||
// Returns (flags, payload, bytesConsumed, ok).
|
||||
// ok is false when the buffer is too short for a complete frame.
|
||||
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
|
||||
if len(buf) < ConnectFrameHeaderSize {
|
||||
return 0, nil, 0, false
|
||||
}
|
||||
flags = buf[0]
|
||||
length := binary.BigEndian.Uint32(buf[1:5])
|
||||
total := ConnectFrameHeaderSize + int(length)
|
||||
if len(buf) < total {
|
||||
return 0, nil, 0, false
|
||||
}
|
||||
return flags, buf[5:total], total, true
|
||||
}
|
||||
|
||||
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
|
||||
// The Code field contains the server-defined error code (e.g. gRPC standard codes
|
||||
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
|
||||
type ConnectError struct {
|
||||
Code string // server-defined error code
|
||||
Message string // human-readable error description
|
||||
}
|
||||
|
||||
func (e *ConnectError) Error() string {
|
||||
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
|
||||
// Returns nil if there is no error in the trailer.
|
||||
// On error, returns a *ConnectError with the server's error code and message.
|
||||
func ParseConnectEndStream(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var trailer struct {
|
||||
Error *struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &trailer); err != nil {
|
||||
return fmt.Errorf("failed to parse Connect end stream: %w", err)
|
||||
}
|
||||
if trailer.Error != nil {
|
||||
code := trailer.Error.Code
|
||||
if code == "" {
|
||||
code = "unknown"
|
||||
}
|
||||
msg := trailer.Error.Message
|
||||
if msg == "" {
|
||||
msg = "Unknown error"
|
||||
}
|
||||
return &ConnectError{Code: code, Message: msg}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
564
internal/auth/cursor/proto/decode.go
Normal file
564
internal/auth/cursor/proto/decode.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
// ServerMessageType identifies the kind of decoded server message.
|
||||
type ServerMessageType int
|
||||
|
||||
const (
|
||||
ServerMsgUnknown ServerMessageType = iota
|
||||
ServerMsgTextDelta // Text content delta
|
||||
ServerMsgThinkingDelta // Thinking/reasoning delta
|
||||
ServerMsgThinkingCompleted // Thinking completed
|
||||
ServerMsgKvGetBlob // Server wants a blob
|
||||
ServerMsgKvSetBlob // Server wants to store a blob
|
||||
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
|
||||
ServerMsgExecMcpArgs // Server wants MCP tool execution
|
||||
ServerMsgExecShellArgs // Rejected: shell command
|
||||
ServerMsgExecReadArgs // Rejected: file read
|
||||
ServerMsgExecWriteArgs // Rejected: file write
|
||||
ServerMsgExecDeleteArgs // Rejected: file delete
|
||||
ServerMsgExecLsArgs // Rejected: directory listing
|
||||
ServerMsgExecGrepArgs // Rejected: grep search
|
||||
ServerMsgExecFetchArgs // Rejected: HTTP fetch
|
||||
ServerMsgExecDiagnostics // Respond with empty diagnostics
|
||||
ServerMsgExecShellStream // Rejected: shell stream
|
||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||
ServerMsgExecOther // Other exec types (respond with empty)
|
||||
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||
ServerMsgHeartbeat // Server heartbeat
|
||||
ServerMsgTokenDelta // Token usage delta
|
||||
ServerMsgCheckpoint // Conversation checkpoint update
|
||||
)
|
||||
|
||||
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||
type DecodedServerMessage struct {
|
||||
Type ServerMessageType
|
||||
|
||||
// For text/thinking deltas
|
||||
Text string
|
||||
|
||||
// For KV messages
|
||||
KvId uint32
|
||||
BlobId []byte // hex-encoded blob ID
|
||||
BlobData []byte // for setBlobArgs
|
||||
|
||||
// For exec messages
|
||||
ExecMsgId uint32
|
||||
ExecId string
|
||||
|
||||
// For MCP args
|
||||
McpToolName string
|
||||
McpToolCallId string
|
||||
McpArgs map[string][]byte // arg name -> protobuf-encoded value
|
||||
|
||||
// For rejection context
|
||||
Path string
|
||||
Command string
|
||||
WorkingDirectory string
|
||||
Url string
|
||||
|
||||
// For other exec - the raw field number for building a response
|
||||
ExecFieldNumber int
|
||||
|
||||
// For TokenDeltaUpdate
|
||||
TokenDelta int64
|
||||
|
||||
// For conversation checkpoint update (raw bytes, not decoded)
|
||||
CheckpointData []byte
|
||||
}
|
||||
|
||||
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
||||
// a structured representation of the first meaningful message found.
|
||||
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
|
||||
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
|
||||
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid tag")
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid bytes field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
// Debug: log top-level ASM fields
|
||||
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
|
||||
|
||||
switch num {
|
||||
case ASM_InteractionUpdate:
|
||||
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
|
||||
decodeInteractionUpdate(val, msg)
|
||||
case ASM_ExecServerMessage:
|
||||
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
|
||||
decodeExecServerMessage(val, msg)
|
||||
case ASM_KvServerMessage:
|
||||
decodeKvServerMessage(val, msg)
|
||||
case ASM_ConversationCheckpoint:
|
||||
msg.Type = ServerMsgCheckpoint
|
||||
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
|
||||
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
|
||||
}
|
||||
|
||||
case protowire.VarintType:
|
||||
_, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid varint field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
default:
|
||||
// Skip unknown wire types
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
|
||||
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||
|
||||
switch num {
|
||||
case IU_TextDelta:
|
||||
msg.Type = ServerMsgTextDelta
|
||||
msg.Text = decodeStringField(val, TDU_Text)
|
||||
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
|
||||
case IU_ThinkingDelta:
|
||||
msg.Type = ServerMsgThinkingDelta
|
||||
msg.Text = decodeStringField(val, TKD_Text)
|
||||
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
|
||||
case IU_ThinkingCompleted:
|
||||
msg.Type = ServerMsgThinkingCompleted
|
||||
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
|
||||
case 2:
|
||||
// tool_call_started - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
|
||||
case 3:
|
||||
// tool_call_completed - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
||||
case 8:
|
||||
// token_delta - extract token count
|
||||
msg.Type = ServerMsgTokenDelta
|
||||
msg.TokenDelta = decodeVarintField(val, 1)
|
||||
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
|
||||
case 13:
|
||||
// heartbeat from server
|
||||
msg.Type = ServerMsgHeartbeat
|
||||
case 14:
|
||||
// turn_ended - critical: model finished generating
|
||||
msg.Type = ServerMsgTurnEnded
|
||||
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
|
||||
case 16:
|
||||
// step_started - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
|
||||
case 17:
|
||||
// step_completed - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
|
||||
default:
|
||||
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.VarintType:
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == KSM_Id {
|
||||
msg.KvId = uint32(val)
|
||||
}
|
||||
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch num {
|
||||
case KSM_GetBlobArgs:
|
||||
msg.Type = ServerMsgKvGetBlob
|
||||
msg.BlobId = decodeBytesField(val, GBA_BlobId)
|
||||
case KSM_SetBlobArgs:
|
||||
msg.Type = ServerMsgKvSetBlob
|
||||
decodeSetBlobArgs(val, msg)
|
||||
}
|
||||
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
switch num {
|
||||
case SBA_BlobId:
|
||||
msg.BlobId = val
|
||||
case SBA_BlobData:
|
||||
msg.BlobData = val
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.VarintType:
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == ESM_Id {
|
||||
msg.ExecMsgId = uint32(val)
|
||||
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
|
||||
}
|
||||
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
// Debug: log all fields found in ExecServerMessage
|
||||
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||
|
||||
switch num {
|
||||
case ESM_ExecId:
|
||||
msg.ExecId = string(val)
|
||||
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
|
||||
case ESM_RequestContextArgs:
|
||||
msg.Type = ServerMsgExecRequestCtx
|
||||
case ESM_McpArgs:
|
||||
msg.Type = ServerMsgExecMcpArgs
|
||||
decodeMcpArgs(val, msg)
|
||||
case ESM_ShellArgs:
|
||||
msg.Type = ServerMsgExecShellArgs
|
||||
decodeShellArgs(val, msg)
|
||||
case ESM_ShellStreamArgs:
|
||||
msg.Type = ServerMsgExecShellStream
|
||||
decodeShellArgs(val, msg)
|
||||
case ESM_ReadArgs:
|
||||
msg.Type = ServerMsgExecReadArgs
|
||||
msg.Path = decodeStringField(val, RA_Path)
|
||||
case ESM_WriteArgs:
|
||||
msg.Type = ServerMsgExecWriteArgs
|
||||
msg.Path = decodeStringField(val, WA_Path)
|
||||
case ESM_DeleteArgs:
|
||||
msg.Type = ServerMsgExecDeleteArgs
|
||||
msg.Path = decodeStringField(val, DA_Path)
|
||||
case ESM_LsArgs:
|
||||
msg.Type = ServerMsgExecLsArgs
|
||||
msg.Path = decodeStringField(val, LA_Path)
|
||||
case ESM_GrepArgs:
|
||||
msg.Type = ServerMsgExecGrepArgs
|
||||
case ESM_FetchArgs:
|
||||
msg.Type = ServerMsgExecFetchArgs
|
||||
msg.Url = decodeStringField(val, FA_Url)
|
||||
case ESM_DiagnosticsArgs:
|
||||
msg.Type = ServerMsgExecDiagnostics
|
||||
case ESM_BackgroundShellSpawn:
|
||||
msg.Type = ServerMsgExecBgShellSpawn
|
||||
decodeShellArgs(val, msg) // same structure
|
||||
case ESM_WriteShellStdinArgs:
|
||||
msg.Type = ServerMsgExecWriteShellStdin
|
||||
default:
|
||||
// Unknown exec types - only set if we haven't identified the type yet
|
||||
// (other fields like span_context (19) come after the exec type field)
|
||||
if msg.Type == ServerMsgUnknown {
|
||||
msg.Type = ServerMsgExecOther
|
||||
msg.ExecFieldNumber = int(num)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
|
||||
msg.McpArgs = make(map[string][]byte)
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch num {
|
||||
case MCA_Name:
|
||||
msg.McpToolName = string(val)
|
||||
case MCA_Args:
|
||||
// Map entries are encoded as submessages with key=1, value=2
|
||||
decodeMapEntry(val, msg.McpArgs)
|
||||
case MCA_ToolCallId:
|
||||
msg.McpToolCallId = string(val)
|
||||
case MCA_ToolName:
|
||||
// ToolName takes precedence if present
|
||||
if msg.McpToolName == "" || string(val) != "" {
|
||||
msg.McpToolName = string(val)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeMapEntry(data []byte, m map[string][]byte) {
|
||||
var key string
|
||||
var value []byte
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == 1 {
|
||||
key = string(val)
|
||||
} else if num == 2 {
|
||||
value = append([]byte(nil), val...)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
if key != "" {
|
||||
m[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
switch num {
|
||||
case SHA_Command:
|
||||
msg.Command = string(val)
|
||||
case SHA_WorkingDirectory:
|
||||
msg.WorkingDirectory = string(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper decoders ---
|
||||
|
||||
// decodeStringField extracts a string from the first matching field in a submessage.
|
||||
func decodeStringField(data []byte, targetField protowire.Number) string {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return string(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// decodeBytesField extracts bytes from the first matching field in a submessage.
|
||||
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return append([]byte(nil), val...)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if typ == protowire.VarintType {
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return int64(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
||||
func BlobIdHex(blobId []byte) string {
|
||||
return hex.EncodeToString(blobId)
|
||||
}
|
||||
|
||||
1244
internal/auth/cursor/proto/descriptor.go
Normal file
1244
internal/auth/cursor/proto/descriptor.go
Normal file
File diff suppressed because it is too large
Load Diff
664
internal/auth/cursor/proto/encode.go
Normal file
664
internal/auth/cursor/proto/encode.go
Normal file
@@ -0,0 +1,664 @@
|
||||
// Package proto provides protobuf encoding for Cursor's gRPC API,
|
||||
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
|
||||
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
|
||||
package proto
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/dynamicpb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// --- Public types ---
|
||||
|
||||
// RunRequestParams holds all data needed to build an AgentRunRequest.
|
||||
type RunRequestParams struct {
|
||||
ModelId string
|
||||
SystemPrompt string
|
||||
UserText string
|
||||
MessageId string
|
||||
ConversationId string
|
||||
Images []ImageData
|
||||
Turns []TurnData
|
||||
McpTools []McpToolDef
|
||||
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
|
||||
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
MimeType string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type TurnData struct {
|
||||
UserText string
|
||||
AssistantText string
|
||||
}
|
||||
|
||||
type McpToolDef struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema json.RawMessage
|
||||
}
|
||||
|
||||
// --- Helper: create a dynamic message and set fields ---
|
||||
|
||||
func newMsg(name string) *dynamicpb.Message {
|
||||
return dynamicpb.NewMessage(Msg(name))
|
||||
}
|
||||
|
||||
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
|
||||
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
|
||||
}
|
||||
|
||||
func setStr(msg *dynamicpb.Message, name, val string) {
|
||||
if val != "" {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
|
||||
}
|
||||
}
|
||||
|
||||
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
|
||||
if len(val) > 0 {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
|
||||
}
|
||||
}
|
||||
|
||||
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
|
||||
}
|
||||
|
||||
func setBool(msg *dynamicpb.Message, name string, val bool) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
|
||||
}
|
||||
|
||||
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
|
||||
}
|
||||
|
||||
func marshal(msg *dynamicpb.Message) []byte {
|
||||
b, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
panic("cursor proto marshal: " + err.Error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// --- Encode functions mirroring cursor-fetch.ts ---
|
||||
|
||||
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
|
||||
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
|
||||
func EncodeHeartbeat() []byte {
|
||||
hb := newMsg("ClientHeartbeat")
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "client_heartbeat", hb)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
|
||||
// Mirrors buildCursorRequest() in cursor-fetch.ts.
|
||||
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
|
||||
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
|
||||
func EncodeRunRequest(p *RunRequestParams) []byte {
|
||||
if p.RawCheckpoint != nil {
|
||||
return encodeRunRequestWithCheckpoint(p)
|
||||
}
|
||||
|
||||
if p.BlobStore == nil {
|
||||
p.BlobStore = make(map[string][]byte)
|
||||
}
|
||||
|
||||
// --- Conversation turns ---
|
||||
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
|
||||
var turnBytes [][]byte
|
||||
for _, turn := range p.Turns {
|
||||
// UserMessage for this turn
|
||||
um := newMsg("UserMessage")
|
||||
setStr(um, "text", turn.UserText)
|
||||
setStr(um, "message_id", generateId())
|
||||
umBytes := marshal(um)
|
||||
|
||||
// Steps (assistant response)
|
||||
var stepBytes [][]byte
|
||||
if turn.AssistantText != "" {
|
||||
am := newMsg("AssistantMessage")
|
||||
setStr(am, "text", turn.AssistantText)
|
||||
step := newMsg("ConversationStep")
|
||||
setMsg(step, "assistant_message", am)
|
||||
stepBytes = append(stepBytes, marshal(step))
|
||||
}
|
||||
|
||||
// AgentConversationTurnStructure (fields are bytes, not submessages)
|
||||
agentTurn := newMsg("AgentConversationTurnStructure")
|
||||
setBytes(agentTurn, "user_message", umBytes)
|
||||
for _, sb := range stepBytes {
|
||||
stepsField := field(agentTurn, "steps")
|
||||
list := agentTurn.Mutable(stepsField).List()
|
||||
list.Append(protoreflect.ValueOfBytes(sb))
|
||||
}
|
||||
|
||||
// ConversationTurnStructure (oneof turn → agentConversationTurn)
|
||||
cts := newMsg("ConversationTurnStructure")
|
||||
setMsg(cts, "agent_conversation_turn", agentTurn)
|
||||
turnBytes = append(turnBytes, marshal(cts))
|
||||
}
|
||||
|
||||
// --- System prompt blob ---
|
||||
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
|
||||
blobId := sha256Sum(systemJSON)
|
||||
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
|
||||
|
||||
// --- ConversationStateStructure ---
|
||||
css := newMsg("ConversationStateStructure")
|
||||
// rootPromptMessagesJson: repeated bytes
|
||||
rootField := field(css, "root_prompt_messages_json")
|
||||
rootList := css.Mutable(rootField).List()
|
||||
rootList.Append(protoreflect.ValueOfBytes(blobId))
|
||||
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
|
||||
turnsField := field(css, "turns")
|
||||
turnsList := css.Mutable(turnsField).List()
|
||||
for _, tb := range turnBytes {
|
||||
turnsList.Append(protoreflect.ValueOfBytes(tb))
|
||||
}
|
||||
turnsOldField := field(css, "turns_old")
|
||||
if turnsOldField != nil {
|
||||
turnsOldList := css.Mutable(turnsOldField).List()
|
||||
for _, tb := range turnBytes {
|
||||
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
|
||||
}
|
||||
}
|
||||
|
||||
// --- UserMessage (current) ---
|
||||
userMessage := newMsg("UserMessage")
|
||||
setStr(userMessage, "text", p.UserText)
|
||||
setStr(userMessage, "message_id", p.MessageId)
|
||||
|
||||
// Images via SelectedContext
|
||||
if len(p.Images) > 0 {
|
||||
sc := newMsg("SelectedContext")
|
||||
imgsField := field(sc, "selected_images")
|
||||
imgsList := sc.Mutable(imgsField).List()
|
||||
for _, img := range p.Images {
|
||||
si := newMsg("SelectedImage")
|
||||
setStr(si, "uuid", generateId())
|
||||
setStr(si, "mime_type", img.MimeType)
|
||||
setBytes(si, "data", img.Data)
|
||||
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||
}
|
||||
setMsg(userMessage, "selected_context", sc)
|
||||
}
|
||||
|
||||
// --- UserMessageAction ---
|
||||
uma := newMsg("UserMessageAction")
|
||||
setMsg(uma, "user_message", userMessage)
|
||||
|
||||
// --- ConversationAction ---
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "user_message_action", uma)
|
||||
|
||||
// --- ModelDetails ---
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
|
||||
// --- AgentRunRequest ---
|
||||
arr := newMsg("AgentRunRequest")
|
||||
setMsg(arr, "conversation_state", css)
|
||||
setMsg(arr, "action", ca)
|
||||
setMsg(arr, "model_details", md)
|
||||
setStr(arr, "conversation_id", p.ConversationId)
|
||||
|
||||
// McpTools
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
setMsg(arr, "mcp_tools", mcpTools)
|
||||
}
|
||||
|
||||
// --- AgentClientMessage ---
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "run_request", arr)
|
||||
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
|
||||
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
|
||||
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
|
||||
// Build UserMessage
|
||||
userMessage := newMsg("UserMessage")
|
||||
setStr(userMessage, "text", p.UserText)
|
||||
setStr(userMessage, "message_id", p.MessageId)
|
||||
if len(p.Images) > 0 {
|
||||
sc := newMsg("SelectedContext")
|
||||
imgsField := field(sc, "selected_images")
|
||||
imgsList := sc.Mutable(imgsField).List()
|
||||
for _, img := range p.Images {
|
||||
si := newMsg("SelectedImage")
|
||||
setStr(si, "uuid", generateId())
|
||||
setStr(si, "mime_type", img.MimeType)
|
||||
setBytes(si, "data", img.Data)
|
||||
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||
}
|
||||
setMsg(userMessage, "selected_context", sc)
|
||||
}
|
||||
|
||||
// Build ConversationAction with UserMessageAction
|
||||
uma := newMsg("UserMessageAction")
|
||||
setMsg(uma, "user_message", userMessage)
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "user_message_action", uma)
|
||||
caBytes := marshal(ca)
|
||||
|
||||
// Build ModelDetails
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
mdBytes := marshal(md)
|
||||
|
||||
// Build McpTools
|
||||
var mcpToolsBytes []byte
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
mcpToolsBytes = marshal(mcpTools)
|
||||
}
|
||||
|
||||
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
|
||||
var arrBuf []byte
|
||||
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
|
||||
// field 2: action = ConversationAction
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
|
||||
// field 3: model_details = ModelDetails
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
|
||||
// field 4: mcp_tools = McpTools
|
||||
if len(mcpToolsBytes) > 0 {
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
|
||||
}
|
||||
// field 5: conversation_id = string
|
||||
if p.ConversationId != "" {
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
|
||||
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
|
||||
}
|
||||
|
||||
// Wrap in AgentClientMessage field 1 (run_request)
|
||||
var acmBuf []byte
|
||||
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
|
||||
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
|
||||
|
||||
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
|
||||
return acmBuf
|
||||
}
|
||||
|
||||
// ResumeRequestParams holds data for a ResumeAction request.
|
||||
type ResumeRequestParams struct {
|
||||
ModelId string
|
||||
ConversationId string
|
||||
McpTools []McpToolDef
|
||||
}
|
||||
|
||||
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
|
||||
// Used to resume a conversation by conversation_id without re-sending full history.
|
||||
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
|
||||
// RequestContext with tools
|
||||
rc := newMsg("RequestContext")
|
||||
if len(p.McpTools) > 0 {
|
||||
toolsField := field(rc, "tools")
|
||||
toolsList := rc.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
}
|
||||
|
||||
// ResumeAction
|
||||
ra := newMsg("ResumeAction")
|
||||
setMsg(ra, "request_context", rc)
|
||||
|
||||
// ConversationAction with resume_action
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "resume_action", ra)
|
||||
|
||||
// ModelDetails
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
|
||||
// AgentRunRequest — no conversation_state needed for resume
|
||||
arr := newMsg("AgentRunRequest")
|
||||
setMsg(arr, "action", ca)
|
||||
setMsg(arr, "model_details", md)
|
||||
setStr(arr, "conversation_id", p.ConversationId)
|
||||
|
||||
// McpTools at top level
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
setMsg(arr, "mcp_tools", mcpTools)
|
||||
}
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "run_request", arr)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// --- KV response encoders ---
|
||||
// Mirrors handleKvMessage() in cursor-fetch.ts
|
||||
|
||||
// EncodeKvGetBlobResult responds to a getBlobArgs request.
|
||||
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
|
||||
result := newMsg("GetBlobResult")
|
||||
if blobData != nil {
|
||||
setBytes(result, "blob_data", blobData)
|
||||
}
|
||||
|
||||
kvc := newMsg("KvClientMessage")
|
||||
setUint32(kvc, "id", kvId)
|
||||
setMsg(kvc, "get_blob_result", result)
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "kv_client_message", kvc)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// EncodeKvSetBlobResult responds to a setBlobArgs request.
|
||||
func EncodeKvSetBlobResult(kvId uint32) []byte {
|
||||
result := newMsg("SetBlobResult")
|
||||
|
||||
kvc := newMsg("KvClientMessage")
|
||||
setUint32(kvc, "id", kvId)
|
||||
setMsg(kvc, "set_blob_result", result)
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "kv_client_message", kvc)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// --- Exec response encoders ---
|
||||
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
|
||||
|
||||
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
|
||||
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
|
||||
// RequestContext with tools
|
||||
rc := newMsg("RequestContext")
|
||||
if len(tools) > 0 {
|
||||
toolsField := field(rc, "tools")
|
||||
toolsList := rc.Mutable(toolsField).List()
|
||||
for _, tool := range tools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
}
|
||||
|
||||
// RequestContextSuccess
|
||||
rcs := newMsg("RequestContextSuccess")
|
||||
setMsg(rcs, "request_context", rc)
|
||||
|
||||
// RequestContextResult (oneof success)
|
||||
rcr := newMsg("RequestContextResult")
|
||||
setMsg(rcr, "success", rcs)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
|
||||
}
|
||||
|
||||
// EncodeExecMcpResult responds with MCP tool result.
|
||||
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
|
||||
textContent := newMsg("McpTextContent")
|
||||
setStr(textContent, "text", content)
|
||||
|
||||
contentItem := newMsg("McpToolResultContentItem")
|
||||
setMsg(contentItem, "text", textContent)
|
||||
|
||||
success := newMsg("McpSuccess")
|
||||
contentField := field(success, "content")
|
||||
contentList := success.Mutable(contentField).List()
|
||||
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
|
||||
setBool(success, "is_error", isError)
|
||||
|
||||
result := newMsg("McpResult")
|
||||
setMsg(result, "success", success)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||
}
|
||||
|
||||
// EncodeExecMcpError responds with MCP error.
|
||||
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
mcpErr := newMsg("McpError")
|
||||
setStr(mcpErr, "error", errMsg)
|
||||
|
||||
result := newMsg("McpResult")
|
||||
setMsg(result, "error", mcpErr)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||
}
|
||||
|
||||
// --- Rejection encoders (mirror handleExecMessage rejections) ---
|
||||
|
||||
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("ReadRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("ReadResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||
rej := newMsg("ShellRejected")
|
||||
setStr(rej, "command", command)
|
||||
setStr(rej, "working_directory", workDir)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("ShellResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("WriteRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("WriteResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("DeleteRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("DeleteResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("LsRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("LsResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
grepErr := newMsg("GrepError")
|
||||
setStr(grepErr, "error", errMsg)
|
||||
result := newMsg("GrepResult")
|
||||
setMsg(result, "error", grepErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
|
||||
fetchErr := newMsg("FetchError")
|
||||
setStr(fetchErr, "url", url)
|
||||
setStr(fetchErr, "error", errMsg)
|
||||
result := newMsg("FetchResult")
|
||||
setMsg(result, "error", fetchErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
|
||||
result := newMsg("DiagnosticsResult")
|
||||
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||
rej := newMsg("ShellRejected")
|
||||
setStr(rej, "command", command)
|
||||
setStr(rej, "working_directory", workDir)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("BackgroundShellSpawnResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
wsErr := newMsg("WriteShellStdinError")
|
||||
setStr(wsErr, "error", errMsg)
|
||||
result := newMsg("WriteShellStdinResult")
|
||||
setMsg(result, "error", wsErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
|
||||
}
|
||||
|
||||
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
|
||||
// Mirrors sendExec() in cursor-fetch.ts.
|
||||
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
|
||||
ecm := newMsg("ExecClientMessage")
|
||||
setUint32(ecm, "id", id)
|
||||
// Force set exec_id even if empty - Cursor requires this field to be set
|
||||
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
|
||||
|
||||
// Debug: check if field exists
|
||||
fd := field(ecm, resultFieldName)
|
||||
if fd == nil {
|
||||
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
|
||||
}
|
||||
|
||||
// Debug: log the actual field being set
|
||||
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
|
||||
|
||||
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "exec_client_message", ecm)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
func listFields(msg *dynamicpb.Message) []string {
|
||||
var names []string
|
||||
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
|
||||
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --- Utilities ---
|
||||
|
||||
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
|
||||
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
|
||||
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
|
||||
if len(jsonData) == 0 {
|
||||
return nil
|
||||
}
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(jsonData, &v); err != nil {
|
||||
return jsonData // fallback to raw JSON if parsing fails
|
||||
}
|
||||
pbVal, err := structpb.NewValue(v)
|
||||
if err != nil {
|
||||
return jsonData // fallback
|
||||
}
|
||||
b, err := proto.Marshal(pbVal)
|
||||
if err != nil {
|
||||
return jsonData // fallback
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
|
||||
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
|
||||
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
|
||||
val := &structpb.Value{}
|
||||
if err := proto.Unmarshal(data, val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return val.AsInterface(), nil
|
||||
}
|
||||
|
||||
func sha256Sum(data []byte) []byte {
|
||||
h := sha256.Sum256(data)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
var idCounter uint64
|
||||
|
||||
func generateId() string {
|
||||
idCounter++
|
||||
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
@@ -0,0 +1,332 @@
|
||||
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
|
||||
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
|
||||
package proto
|
||||
|
||||
// AgentClientMessage (msg 118) oneof "message"
|
||||
const (
|
||||
ACM_RunRequest = 1 // AgentRunRequest
|
||||
ACM_ExecClientMessage = 2 // ExecClientMessage
|
||||
ACM_KvClientMessage = 3 // KvClientMessage
|
||||
ACM_ConversationAction = 4 // ConversationAction
|
||||
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
|
||||
ACM_InteractionResponse = 6 // InteractionResponse
|
||||
ACM_ClientHeartbeat = 7 // ClientHeartbeat
|
||||
)
|
||||
|
||||
// AgentServerMessage (msg 119) oneof "message"
|
||||
const (
|
||||
ASM_InteractionUpdate = 1 // InteractionUpdate
|
||||
ASM_ExecServerMessage = 2 // ExecServerMessage
|
||||
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
|
||||
ASM_KvServerMessage = 4 // KvServerMessage
|
||||
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
|
||||
ASM_InteractionQuery = 7 // InteractionQuery
|
||||
)
|
||||
|
||||
// AgentRunRequest (msg 91)
|
||||
const (
|
||||
ARR_ConversationState = 1 // ConversationStateStructure
|
||||
ARR_Action = 2 // ConversationAction
|
||||
ARR_ModelDetails = 3 // ModelDetails
|
||||
ARR_McpTools = 4 // McpTools
|
||||
ARR_ConversationId = 5 // string (optional)
|
||||
)
|
||||
|
||||
// ConversationStateStructure (msg 83)
|
||||
const (
|
||||
CSS_RootPromptMessagesJson = 1 // repeated bytes
|
||||
CSS_TurnsOld = 2 // repeated bytes (deprecated)
|
||||
CSS_Todos = 3 // repeated bytes
|
||||
CSS_PendingToolCalls = 4 // repeated string
|
||||
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
|
||||
CSS_PreviousWorkspaceUris = 9 // repeated string
|
||||
CSS_SelfSummaryCount = 17 // uint32
|
||||
CSS_ReadPaths = 18 // repeated string
|
||||
)
|
||||
|
||||
// ConversationAction (msg 54) oneof "action"
|
||||
const (
|
||||
CA_UserMessageAction = 1 // UserMessageAction
|
||||
)
|
||||
|
||||
// UserMessageAction (msg 55)
|
||||
const (
|
||||
UMA_UserMessage = 1 // UserMessage
|
||||
)
|
||||
|
||||
// UserMessage (msg 63)
|
||||
const (
|
||||
UM_Text = 1 // string
|
||||
UM_MessageId = 2 // string
|
||||
UM_SelectedContext = 3 // SelectedContext (optional)
|
||||
)
|
||||
|
||||
// SelectedContext
|
||||
const (
|
||||
SC_SelectedImages = 1 // repeated SelectedImage
|
||||
)
|
||||
|
||||
// SelectedImage
|
||||
const (
|
||||
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
|
||||
SI_Uuid = 2 // string
|
||||
SI_Path = 3 // string
|
||||
SI_MimeType = 7 // string
|
||||
SI_Data = 8 // bytes (oneof dataOrBlobId)
|
||||
)
|
||||
|
||||
// ModelDetails (msg 88)
|
||||
const (
|
||||
MD_ModelId = 1 // string
|
||||
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
|
||||
MD_DisplayModelId = 3 // string
|
||||
MD_DisplayName = 4 // string
|
||||
)
|
||||
|
||||
// McpTools (msg 307)
|
||||
const (
|
||||
MT_McpTools = 1 // repeated McpToolDefinition
|
||||
)
|
||||
|
||||
// McpToolDefinition (msg 306)
|
||||
const (
|
||||
MTD_Name = 1 // string
|
||||
MTD_Description = 2 // string
|
||||
MTD_InputSchema = 3 // bytes
|
||||
MTD_ProviderIdentifier = 4 // string
|
||||
MTD_ToolName = 5 // string
|
||||
)
|
||||
|
||||
// ConversationTurnStructure (msg 70) oneof "turn"
|
||||
const (
|
||||
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
|
||||
)
|
||||
|
||||
// AgentConversationTurnStructure (msg 72)
|
||||
const (
|
||||
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
|
||||
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
|
||||
)
|
||||
|
||||
// ConversationStep (msg 53) oneof "message"
|
||||
const (
|
||||
CS_AssistantMessage = 1 // AssistantMessage
|
||||
)
|
||||
|
||||
// AssistantMessage
|
||||
const (
|
||||
AM_Text = 1 // string
|
||||
)
|
||||
|
||||
// --- Server-side message fields ---
|
||||
|
||||
// InteractionUpdate oneof "message"
|
||||
const (
|
||||
IU_TextDelta = 1 // TextDeltaUpdate
|
||||
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||
)
|
||||
|
||||
// TextDeltaUpdate (msg 92)
|
||||
const (
|
||||
TDU_Text = 1 // string
|
||||
)
|
||||
|
||||
// ThinkingDeltaUpdate (msg 97)
|
||||
const (
|
||||
TKD_Text = 1 // string
|
||||
)
|
||||
|
||||
// KvServerMessage (msg 271)
|
||||
const (
|
||||
KSM_Id = 1 // uint32
|
||||
KSM_GetBlobArgs = 2 // GetBlobArgs
|
||||
KSM_SetBlobArgs = 3 // SetBlobArgs
|
||||
)
|
||||
|
||||
// GetBlobArgs (msg 267)
|
||||
const (
|
||||
GBA_BlobId = 1 // bytes
|
||||
)
|
||||
|
||||
// SetBlobArgs (msg 269)
|
||||
const (
|
||||
SBA_BlobId = 1 // bytes
|
||||
SBA_BlobData = 2 // bytes
|
||||
)
|
||||
|
||||
// KvClientMessage (msg 272)
|
||||
const (
|
||||
KCM_Id = 1 // uint32
|
||||
KCM_GetBlobResult = 2 // GetBlobResult
|
||||
KCM_SetBlobResult = 3 // SetBlobResult
|
||||
)
|
||||
|
||||
// GetBlobResult (msg 268)
|
||||
const (
|
||||
GBR_BlobData = 1 // bytes (optional)
|
||||
)
|
||||
|
||||
// ExecServerMessage
|
||||
const (
|
||||
ESM_Id = 1 // uint32
|
||||
ESM_ExecId = 15 // string
|
||||
// oneof message:
|
||||
ESM_ShellArgs = 2 // ShellArgs
|
||||
ESM_WriteArgs = 3 // WriteArgs
|
||||
ESM_DeleteArgs = 4 // DeleteArgs
|
||||
ESM_GrepArgs = 5 // GrepArgs
|
||||
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||
ESM_LsArgs = 8 // LsArgs
|
||||
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||
ESM_McpArgs = 11 // McpArgs
|
||||
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||
ESM_FetchArgs = 20 // FetchArgs
|
||||
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||
)
|
||||
|
||||
// ExecClientMessage
|
||||
const (
|
||||
ECM_Id = 1 // uint32
|
||||
ECM_ExecId = 15 // string
|
||||
// oneof message (mirrors server fields):
|
||||
ECM_ShellResult = 2
|
||||
ECM_WriteResult = 3
|
||||
ECM_DeleteResult = 4
|
||||
ECM_GrepResult = 5
|
||||
ECM_ReadResult = 7
|
||||
ECM_LsResult = 8
|
||||
ECM_DiagnosticsResult = 9
|
||||
ECM_RequestContextResult = 10
|
||||
ECM_McpResult = 11
|
||||
ECM_ShellStream = 14
|
||||
ECM_BackgroundShellSpawnRes = 16
|
||||
ECM_FetchResult = 20
|
||||
ECM_WriteShellStdinResult = 23
|
||||
)
|
||||
|
||||
// McpArgs
|
||||
const (
|
||||
MCA_Name = 1 // string
|
||||
MCA_Args = 2 // map<string, bytes>
|
||||
MCA_ToolCallId = 3 // string
|
||||
MCA_ProviderIdentifier = 4 // string
|
||||
MCA_ToolName = 5 // string
|
||||
)
|
||||
|
||||
// RequestContextResult oneof "result"
|
||||
const (
|
||||
RCR_Success = 1 // RequestContextSuccess
|
||||
RCR_Error = 2 // RequestContextError
|
||||
)
|
||||
|
||||
// RequestContextSuccess (msg 337)
|
||||
const (
|
||||
RCS_RequestContext = 1 // RequestContext
|
||||
)
|
||||
|
||||
// RequestContext
|
||||
const (
|
||||
RC_Rules = 2 // repeated CursorRule
|
||||
RC_Tools = 7 // repeated McpToolDefinition
|
||||
)
|
||||
|
||||
// McpResult oneof "result"
|
||||
const (
|
||||
MCR_Success = 1 // McpSuccess
|
||||
MCR_Error = 2 // McpError
|
||||
MCR_Rejected = 3 // McpRejected
|
||||
)
|
||||
|
||||
// McpSuccess (msg 290)
|
||||
const (
|
||||
MCS_Content = 1 // repeated McpToolResultContentItem
|
||||
MCS_IsError = 2 // bool
|
||||
)
|
||||
|
||||
// McpToolResultContentItem oneof "content"
|
||||
const (
|
||||
MTRCI_Text = 1 // McpTextContent
|
||||
)
|
||||
|
||||
// McpTextContent (msg 287)
|
||||
const (
|
||||
MTC_Text = 1 // string
|
||||
)
|
||||
|
||||
// McpError (msg 291)
|
||||
const (
|
||||
MCE_Error = 1 // string
|
||||
)
|
||||
|
||||
// --- Rejection messages ---
|
||||
|
||||
// ReadRejected: path=1, reason=2
|
||||
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
|
||||
// WriteRejected: path=1, reason=2
|
||||
// DeleteRejected: path=1, reason=2
|
||||
// LsRejected: path=1, reason=2
|
||||
// GrepError: error=1
|
||||
// FetchError: url=1, error=2
|
||||
// WriteShellStdinError: error=1
|
||||
|
||||
// ReadResult oneof: success=1, error=2, rejected=3
|
||||
// ShellResult oneof: success=1 (+ various), rejected=?
|
||||
// The TS code uses specific result field numbers from the oneof:
|
||||
const (
|
||||
RR_Rejected = 3 // ReadResult.rejected
|
||||
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||
WR_Rejected = 5 // WriteResult.rejected
|
||||
DR_Rejected = 3 // DeleteResult.rejected
|
||||
LR_Rejected = 3 // LsResult.rejected
|
||||
GR_Error = 2 // GrepResult.error
|
||||
FR_Error = 2 // FetchResult.error
|
||||
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
|
||||
WSSR_Error = 2 // WriteShellStdinResult.error
|
||||
)
|
||||
|
||||
// --- Rejection struct fields ---
|
||||
const (
|
||||
REJ_Path = 1
|
||||
REJ_Reason = 2
|
||||
SREJ_Command = 1
|
||||
SREJ_WorkingDir = 2
|
||||
SREJ_Reason = 3
|
||||
SREJ_IsReadonly = 4
|
||||
GERR_Error = 1
|
||||
FERR_Url = 1
|
||||
FERR_Error = 2
|
||||
)
|
||||
|
||||
// ReadArgs
|
||||
const (
|
||||
RA_Path = 1 // string
|
||||
)
|
||||
|
||||
// WriteArgs
|
||||
const (
|
||||
WA_Path = 1 // string
|
||||
)
|
||||
|
||||
// DeleteArgs
|
||||
const (
|
||||
DA_Path = 1 // string
|
||||
)
|
||||
|
||||
// LsArgs
|
||||
const (
|
||||
LA_Path = 1 // string
|
||||
)
|
||||
|
||||
// ShellArgs
|
||||
const (
|
||||
SHA_Command = 1 // string
|
||||
SHA_WorkingDirectory = 2 // string
|
||||
)
|
||||
|
||||
// FetchArgs
|
||||
const (
|
||||
FA_Url = 1 // string
|
||||
)
|
||||
313
internal/auth/cursor/proto/h2stream.go
Normal file
313
internal/auth/cursor/proto/h2stream.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||
)
|
||||
|
||||
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||
type H2Stream struct {
|
||||
framer *http2.Framer
|
||||
conn net.Conn
|
||||
streamID uint32
|
||||
mu sync.Mutex
|
||||
id string // unique identifier for debugging
|
||||
frameNum int64 // sequential frame counter for debugging
|
||||
|
||||
dataCh chan []byte
|
||||
doneCh chan struct{}
|
||||
err error
|
||||
|
||||
// Send-side flow control
|
||||
sendWindow int32 // available bytes we can send on this stream
|
||||
connWindow int32 // available bytes on the connection level
|
||||
windowCond *sync.Cond // signaled when window is updated
|
||||
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||
}
|
||||
|
||||
// ID returns the unique identifier for this stream (for logging).
|
||||
func (s *H2Stream) ID() string { return s.id }
|
||||
|
||||
// FrameNum returns the current frame number for debugging.
|
||||
func (s *H2Stream) FrameNum() int64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.frameNum
|
||||
}
|
||||
|
||||
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
|
||||
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
|
||||
NextProtos: []string{"h2"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
|
||||
}
|
||||
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: server did not negotiate h2")
|
||||
}
|
||||
|
||||
framer := http2.NewFramer(tlsConn, tlsConn)
|
||||
|
||||
// Client connection preface
|
||||
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||
}
|
||||
|
||||
// Send initial SETTINGS (tell server how much WE can receive)
|
||||
if err := framer.WriteSettings(
|
||||
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||
); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||
}
|
||||
|
||||
// Connection-level window update (for receiving)
|
||||
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||
}
|
||||
|
||||
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||
// Track server's initial window size (how much WE can send)
|
||||
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||
for i := 0; i < 10; i++ {
|
||||
f, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
|
||||
}
|
||||
switch sf := f.(type) {
|
||||
case *http2.SettingsFrame:
|
||||
if !sf.IsAck() {
|
||||
sf.ForeachSetting(func(s http2.Setting) error {
|
||||
if s.ID == http2.SettingInitialWindowSize {
|
||||
serverInitialWindowSize = int32(s.Val)
|
||||
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
framer.WriteSettingsAck()
|
||||
} else {
|
||||
goto handshakeDone
|
||||
}
|
||||
case *http2.WindowUpdateFrame:
|
||||
if sf.StreamID == 0 {
|
||||
connWindowSize += int32(sf.Increment)
|
||||
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||
}
|
||||
default:
|
||||
// unexpected but continue
|
||||
}
|
||||
}
|
||||
handshakeDone:
|
||||
|
||||
// Build HEADERS
|
||||
streamID := uint32(1)
|
||||
var hdrBuf []byte
|
||||
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
|
||||
if p, ok := headers[":path"]; ok {
|
||||
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
|
||||
}
|
||||
for k, v := range headers {
|
||||
if len(k) > 0 && k[0] == ':' {
|
||||
continue
|
||||
}
|
||||
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
|
||||
if err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
BlockFragment: hdrBuf,
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
}); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: headers write failed: %w", err)
|
||||
}
|
||||
|
||||
s := &H2Stream{
|
||||
framer: framer,
|
||||
conn: tlsConn,
|
||||
streamID: streamID,
|
||||
dataCh: make(chan []byte, 256),
|
||||
doneCh: make(chan struct{}),
|
||||
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||
frameNum: 0,
|
||||
sendWindow: serverInitialWindowSize,
|
||||
connWindow: connWindowSize,
|
||||
}
|
||||
s.windowCond = sync.NewCond(&s.windowMu)
|
||||
go s.readLoop()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Write sends a DATA frame on the stream, respecting flow control.
|
||||
func (s *H2Stream) Write(data []byte) error {
|
||||
for len(data) > 0 {
|
||||
chunk := data
|
||||
if len(chunk) > maxFramePayload {
|
||||
chunk = data[:maxFramePayload]
|
||||
}
|
||||
|
||||
// Wait for flow control window
|
||||
s.windowMu.Lock()
|
||||
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||
s.windowCond.Wait()
|
||||
}
|
||||
// Limit chunk to available window
|
||||
allowed := int(s.sendWindow)
|
||||
if int(s.connWindow) < allowed {
|
||||
allowed = int(s.connWindow)
|
||||
}
|
||||
if len(chunk) > allowed {
|
||||
chunk = chunk[:allowed]
|
||||
}
|
||||
s.sendWindow -= int32(len(chunk))
|
||||
s.connWindow -= int32(len(chunk))
|
||||
s.windowMu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[len(chunk):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns the channel of received data chunks.
|
||||
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
|
||||
|
||||
// Done returns a channel closed when the stream ends.
|
||||
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
||||
|
||||
// Err returns the error (if any) that caused the stream to close.
|
||||
// Returns nil for a clean shutdown (EOF / StreamEnded).
|
||||
func (s *H2Stream) Err() error { return s.err }
|
||||
|
||||
// Close tears down the connection.
|
||||
func (s *H2Stream) Close() {
|
||||
s.conn.Close()
|
||||
// Unblock any writers waiting on flow control
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
|
||||
func (s *H2Stream) readLoop() {
|
||||
defer close(s.doneCh)
|
||||
defer close(s.dataCh)
|
||||
|
||||
for {
|
||||
f, err := s.framer.ReadFrame()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
s.err = err
|
||||
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Increment frame counter
|
||||
s.mu.Lock()
|
||||
s.frameNum++
|
||||
s.mu.Unlock()
|
||||
|
||||
switch frame := f.(type) {
|
||||
case *http2.DataFrame:
|
||||
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||
cp := make([]byte, len(frame.Data()))
|
||||
copy(cp, frame.Data())
|
||||
s.dataCh <- cp
|
||||
|
||||
// Flow control: send WINDOW_UPDATE for received data
|
||||
s.mu.Lock()
|
||||
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||
s.mu.Unlock()
|
||||
}
|
||||
if frame.StreamEnded() {
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.HeadersFrame:
|
||||
if frame.StreamEnded() {
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.RSTStreamFrame:
|
||||
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.GoAwayFrame:
|
||||
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.PingFrame:
|
||||
if !frame.IsAck() {
|
||||
s.mu.Lock()
|
||||
s.framer.WritePing(true, frame.Data)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.SettingsFrame:
|
||||
if !frame.IsAck() {
|
||||
// Check for window size changes
|
||||
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||
if setting.ID == http2.SettingInitialWindowSize {
|
||||
s.windowMu.Lock()
|
||||
delta := int32(setting.Val) - s.sendWindow
|
||||
s.sendWindow += delta
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.mu.Lock()
|
||||
s.framer.WriteSettingsAck()
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.WindowUpdateFrame:
|
||||
// Update send-side flow control window
|
||||
s.windowMu.Lock()
|
||||
if frame.StreamID == 0 {
|
||||
s.connWindow += int32(frame.Increment)
|
||||
} else if frame.StreamID == s.streamID {
|
||||
s.sendWindow += int32(frame.Increment)
|
||||
}
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sliceWriter struct{ buf *[]byte }
|
||||
|
||||
func (w *sliceWriter) Write(p []byte) (int, error) {
|
||||
*w.buf = append(*w.buf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -24,6 +24,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||
sdkAuth.NewCursorAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
37
internal/cmd/cursor_login.go
Normal file
37
internal/cmd/cursor_login.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
|
||||
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Cursor authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
log.Infof("Authentication saved to %s", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
log.Infof("Authenticated as %s", record.Label)
|
||||
}
|
||||
log.Info("Cursor authentication successful!")
|
||||
}
|
||||
@@ -211,6 +211,10 @@ type QuotaExceeded struct {
|
||||
|
||||
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||
|
||||
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
|
||||
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
|
||||
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
|
||||
}
|
||||
|
||||
// RoutingConfig configures how credentials are selected for requests.
|
||||
@@ -257,8 +261,8 @@ type AmpCode struct {
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
||||
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||
// is used for the upstream Amp request.
|
||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
@@ -380,6 +384,11 @@ type ClaudeKey struct {
|
||||
|
||||
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||
|
||||
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
|
||||
// Claude /v1/messages requests. It is disabled by default so upstream seed
|
||||
// changes do not alter the proxy's legacy behavior.
|
||||
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
|
||||
}
|
||||
|
||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
@@ -41,15 +42,17 @@ type RequestLogger interface {
|
||||
// - statusCode: The response status code
|
||||
// - responseHeaders: The response headers
|
||||
// - response: The raw response data
|
||||
// - websocketTimeline: Optional downstream websocket event timeline
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - apiWebsocketTimeline: Optional upstream websocket event timeline
|
||||
// - requestID: Optional request ID for log file naming
|
||||
// - requestTimestamp: When the request was received
|
||||
// - apiResponseTimestamp: When the API response was received
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||
|
||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||
//
|
||||
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIResponse(apiResponse []byte) error
|
||||
|
||||
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
|
||||
// This should be called when upstream communication happened over websocket.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
|
||||
|
||||
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
requestHeaders,
|
||||
body,
|
||||
requestBodyPath,
|
||||
websocketTimeline,
|
||||
apiRequest,
|
||||
apiResponse,
|
||||
apiWebsocketTimeline,
|
||||
apiResponseErrors,
|
||||
statusCode,
|
||||
responseHeaders,
|
||||
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
requestHeaders map[string][]string,
|
||||
requestBody []byte,
|
||||
requestBodyPath string,
|
||||
websocketTimeline []byte,
|
||||
apiRequest []byte,
|
||||
apiResponse []byte,
|
||||
apiWebsocketTimeline []byte,
|
||||
apiResponseErrors []*interfaces.ErrorMessage,
|
||||
statusCode int,
|
||||
responseHeaders map[string][]string,
|
||||
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
if requestTimestamp.IsZero() {
|
||||
requestTimestamp = time.Now()
|
||||
}
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
||||
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
|
||||
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if isWebsocketTranscript {
|
||||
// Intentionally omit the generic downstream HTTP response section for websocket
|
||||
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
|
||||
// and appending a one-off upgrade response snapshot would dilute that transcript.
|
||||
return nil
|
||||
}
|
||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||
}
|
||||
|
||||
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
|
||||
body []byte,
|
||||
bodyPath string,
|
||||
timestamp time.Time,
|
||||
downstreamTransport string,
|
||||
upstreamTransport string,
|
||||
includeBody bool,
|
||||
) error {
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if strings.TrimSpace(downstreamTransport) != "" {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(upstreamTransport) != "" {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if !includeBody {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
bodyTrailingNewlines := 1
|
||||
if bodyPath != "" {
|
||||
bodyFile, errOpen := os.Open(bodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
||||
tracker := &trailingNewlineTrackingWriter{writer: w}
|
||||
written, errCopy := io.Copy(tracker, bodyFile)
|
||||
if errCopy != nil {
|
||||
_ = bodyFile.Close()
|
||||
return errCopy
|
||||
}
|
||||
if written > 0 {
|
||||
bodyTrailingNewlines = tracker.trailingNewlines
|
||||
}
|
||||
if errClose := bodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||
}
|
||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||
return errWrite
|
||||
} else if len(body) > 0 {
|
||||
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func countTrailingNewlinesBytes(payload []byte) int {
|
||||
count := 0
|
||||
for i := len(payload) - 1; i >= 0; i-- {
|
||||
if payload[i] != '\n' {
|
||||
break
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
|
||||
missingNewlines := 3 - trailingNewlines
|
||||
if missingNewlines <= 0 {
|
||||
return nil
|
||||
}
|
||||
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
|
||||
return errWrite
|
||||
}
|
||||
|
||||
type trailingNewlineTrackingWriter struct {
|
||||
writer io.Writer
|
||||
trailingNewlines int
|
||||
}
|
||||
|
||||
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
|
||||
written, errWrite := t.writer.Write(payload)
|
||||
if written > 0 {
|
||||
writtenPayload := payload[:written]
|
||||
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
|
||||
if trailingNewlines == len(writtenPayload) {
|
||||
t.trailingNewlines += trailingNewlines
|
||||
} else {
|
||||
t.trailingNewlines = trailingNewlines
|
||||
}
|
||||
}
|
||||
return written, errWrite
|
||||
}
|
||||
|
||||
func hasSectionPayload(payload []byte) bool {
|
||||
return len(bytes.TrimSpace(payload)) > 0
|
||||
}
|
||||
|
||||
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
|
||||
if hasSectionPayload(websocketTimeline) {
|
||||
return "websocket"
|
||||
}
|
||||
for key, values := range headers {
|
||||
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
|
||||
for _, value := range values {
|
||||
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
|
||||
return "websocket"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
|
||||
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
|
||||
hasWS := hasSectionPayload(apiWebsocketTimeline)
|
||||
switch {
|
||||
case hasHTTP && hasWS:
|
||||
return "websocket+http"
|
||||
case hasWS:
|
||||
return "websocket"
|
||||
case hasHTTP:
|
||||
return "http"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||
return errWrite
|
||||
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
trailingNewlines := 1
|
||||
if apiResponseErrors[i].Error != nil {
|
||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
||||
errText := apiResponseErrors[i].Error.Error()
|
||||
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errText != "" {
|
||||
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
var bufferedReader *bufio.Reader
|
||||
if responseReader != nil {
|
||||
bufferedReader = bufio.NewReader(responseReader)
|
||||
}
|
||||
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if responseReader != nil {
|
||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
||||
if bufferedReader != nil {
|
||||
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
|
||||
return errCopy
|
||||
}
|
||||
}
|
||||
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
|
||||
if reader == nil {
|
||||
return false
|
||||
}
|
||||
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
|
||||
return true
|
||||
}
|
||||
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - websocketTimeline: The downstream websocket event timeline
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - response: The raw response data
|
||||
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
||||
//
|
||||
// Returns:
|
||||
// - string: The formatted log content
|
||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||
var content strings.Builder
|
||||
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
|
||||
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||
|
||||
// Request info
|
||||
content.WriteString(l.formatRequestInfo(url, method, headers, body))
|
||||
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
|
||||
|
||||
if len(websocketTimeline) > 0 {
|
||||
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
|
||||
content.Write(websocketTimeline)
|
||||
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
|
||||
content.Write(websocketTimeline)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(apiWebsocketTimeline) > 0 {
|
||||
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
|
||||
content.Write(apiWebsocketTimeline)
|
||||
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
|
||||
content.Write(apiWebsocketTimeline)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(apiRequest) > 0 {
|
||||
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
||||
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if isWebsocketTranscript {
|
||||
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
|
||||
// timeline sections instead of a generic downstream HTTP response block.
|
||||
return content.String()
|
||||
}
|
||||
|
||||
// Response section
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
|
||||
//
|
||||
// Returns:
|
||||
// - string: The formatted request information
|
||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
|
||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
|
||||
var content strings.Builder
|
||||
|
||||
content.WriteString("=== REQUEST INFO ===\n")
|
||||
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
||||
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||
if strings.TrimSpace(downstreamTransport) != "" {
|
||||
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
|
||||
}
|
||||
if strings.TrimSpace(upstreamTransport) != "" {
|
||||
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
|
||||
}
|
||||
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
content.WriteString("\n")
|
||||
|
||||
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
||||
}
|
||||
content.WriteString("\n")
|
||||
|
||||
if !includeBody {
|
||||
return content.String()
|
||||
}
|
||||
|
||||
content.WriteString("=== REQUEST BODY ===\n")
|
||||
content.Write(body)
|
||||
content.WriteString("\n\n")
|
||||
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
|
||||
// apiResponse stores the upstream API response data.
|
||||
apiResponse []byte
|
||||
|
||||
// apiWebsocketTimeline stores the upstream websocket event timeline.
|
||||
apiWebsocketTimeline []byte
|
||||
|
||||
// apiResponseTimestamp captures when the API response was received.
|
||||
apiResponseTimestamp time.Time
|
||||
}
|
||||
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil (buffering cannot fail)
|
||||
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||
if len(apiWebsocketTimeline) == 0 {
|
||||
return nil
|
||||
}
|
||||
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||
if !timestamp.IsZero() {
|
||||
w.apiResponseTimestamp = timestamp
|
||||
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||
|
||||
// Close finalizes the log file and cleans up resources.
|
||||
// It writes all buffered data to the file in the correct order:
|
||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if closing fails, nil otherwise
|
||||
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
|
||||
//
|
||||
// Parameters:
|
||||
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil
|
||||
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||
|
||||
// Close is a no-op implementation that does nothing and always returns nil.
|
||||
|
||||
@@ -93,6 +93,30 @@ func GetAntigravityModels() []*ModelInfo {
|
||||
func GetCodeBuddyModels() []*ModelInfo {
|
||||
now := int64(1748044800) // 2025-05-24
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "auto",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Auto",
|
||||
Description: "Automatic model selection via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "glm-5.0-turbo",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "GLM-5.0 Turbo",
|
||||
Description: "GLM-5.0 Turbo via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "glm-5.0",
|
||||
Object: "model",
|
||||
@@ -118,13 +142,13 @@ func GetCodeBuddyModels() []*ModelInfo {
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "minimax-m2.5",
|
||||
ID: "minimax-m2.7",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "MiniMax M2.5",
|
||||
Description: "MiniMax M2.5 via CodeBuddy",
|
||||
DisplayName: "MiniMax M2.7",
|
||||
Description: "MiniMax M2.7 via CodeBuddy",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
@@ -141,6 +165,19 @@ func GetCodeBuddyModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2-thinking",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Kimi K2 Thinking",
|
||||
Description: "Kimi K2 Thinking via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "deepseek-v3-2-volc",
|
||||
Object: "model",
|
||||
@@ -148,24 +185,11 @@ func GetCodeBuddyModels() []*ModelInfo {
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
|
||||
Description: "DeepSeek V3.2 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "hunyuan-2.0-thinking",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Hunyuan 2.0 Thinking",
|
||||
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,11 +255,25 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetAntigravityModels()
|
||||
case "codebuddy":
|
||||
return GetCodeBuddyModels()
|
||||
case "cursor":
|
||||
return GetCursorModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCursorModels returns the fallback Cursor model definitions.
|
||||
func GetCursorModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
|
||||
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
|
||||
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
|
||||
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
@@ -260,6 +298,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
GetCodeBuddyModels(),
|
||||
GetCursorModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
@@ -462,6 +501,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4-mini",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.4 mini",
|
||||
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
Object: "model",
|
||||
@@ -556,6 +608,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-preview",
|
||||
@@ -567,6 +620,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
@@ -576,8 +630,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
ContextLength: 173000,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
@@ -587,8 +642,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Gemini 3 Flash (Preview)",
|
||||
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
ContextLength: 173000,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "grok-code-fast-1",
|
||||
|
||||
29
internal/registry/model_definitions_test.go
Normal file
29
internal/registry/model_definitions_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package registry
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGitHubCopilotGeminiModelsAreChatOnly(t *testing.T) {
|
||||
models := GetGitHubCopilotModels()
|
||||
required := map[string]bool{
|
||||
"gemini-2.5-pro": false,
|
||||
"gemini-3-pro-preview": false,
|
||||
"gemini-3.1-pro-preview": false,
|
||||
"gemini-3-flash-preview": false,
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if _, ok := required[model.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
required[model.ID] = true
|
||||
if len(model.SupportedEndpoints) != 1 || model.SupportedEndpoints[0] != "/chat/completions" {
|
||||
t.Fatalf("model %q supported endpoints = %v, want [/chat/completions]", model.ID, model.SupportedEndpoints)
|
||||
}
|
||||
}
|
||||
|
||||
for modelID, found := range required {
|
||||
if !found {
|
||||
t.Fatalf("expected GitHub Copilot model %q in definitions", modelID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -46,8 +48,16 @@ func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Man
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio).
|
||||
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||
// PrepareRequest prepares the HTTP request for execution.
|
||||
func (e *AIStudioExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -66,6 +76,9 @@ func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, fmt.Errorf("aistudio executor: missing auth")
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
|
||||
return nil, fmt.Errorf("aistudio executor: request URL is empty")
|
||||
}
|
||||
@@ -115,8 +128,8 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
translatedReq, body, err := e.translateRequest(req, opts, false)
|
||||
if err != nil {
|
||||
@@ -130,6 +143,11 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: body.payload,
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -137,7 +155,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
@@ -151,17 +169,17 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
if len(wsResp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||
}
|
||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
reporter.Publish(ctx, helps.ParseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
|
||||
@@ -174,8 +192,8 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
translatedReq, body, err := e.translateRequest(req, opts, true)
|
||||
if err != nil {
|
||||
@@ -189,13 +207,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: body.payload,
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
@@ -208,24 +231,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
})
|
||||
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
firstEvent, ok := <-wsStream
|
||||
if !ok {
|
||||
err = fmt.Errorf("wsrelay: stream closed before start")
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
|
||||
metadataLogged := false
|
||||
if firstEvent.Status > 0 {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
|
||||
metadataLogged = true
|
||||
}
|
||||
var body bytes.Buffer
|
||||
if len(firstEvent.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||
body.Write(firstEvent.Payload)
|
||||
}
|
||||
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -233,18 +256,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
for event := range wsStream {
|
||||
if event.Err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
if body.Len() == 0 {
|
||||
body.WriteString(event.Err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
if !metadataLogged && event.Status > 0 {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
body.Write(event.Payload)
|
||||
}
|
||||
if event.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -260,23 +283,23 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged := false
|
||||
processEvent := func(event wsrelay.StreamEvent) bool {
|
||||
if event.Err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||
return false
|
||||
}
|
||||
switch event.Type {
|
||||
case wsrelay.MessageTypeStreamStart:
|
||||
if !metadataLogged && event.Status > 0 {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
metadataLogged = true
|
||||
}
|
||||
case wsrelay.MessageTypeStreamChunk:
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
filtered := FilterSSEUsageMetadata(event.Payload)
|
||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
filtered := helps.FilterSSEUsageMetadata(event.Payload)
|
||||
if detail, ok := helps.ParseGeminiStreamUsage(filtered); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
@@ -288,21 +311,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return false
|
||||
case wsrelay.MessageTypeHTTPResp:
|
||||
if !metadataLogged && event.Status > 0 {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
||||
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
|
||||
return false
|
||||
case wsrelay.MessageTypeError:
|
||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||
return false
|
||||
}
|
||||
@@ -345,7 +368,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
@@ -358,12 +381,12 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
})
|
||||
resp, err := e.relay.NonStream(ctx, authID, wsReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||
if len(resp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||
}
|
||||
if resp.Status < 200 || resp.Status >= 300 {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||
@@ -404,8 +427,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
return nil, translatedPayload{}, err
|
||||
}
|
||||
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -47,12 +48,41 @@ const (
|
||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
antigravityCreditsRetryTTL = 5 * time.Hour
|
||||
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||
)
|
||||
|
||||
type antigravity429Category string
|
||||
|
||||
const (
|
||||
antigravity429Unknown antigravity429Category = "unknown"
|
||||
antigravity429RateLimited antigravity429Category = "rate_limited"
|
||||
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
|
||||
)
|
||||
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
antigravityCreditsExhaustedByAuth sync.Map
|
||||
antigravityPreferCreditsByModel sync.Map
|
||||
antigravityQuotaExhaustedKeywords = []string{
|
||||
"quota_exhausted",
|
||||
"quota exhausted",
|
||||
}
|
||||
antigravityCreditsExhaustedKeywords = []string{
|
||||
"google_one_ai",
|
||||
"insufficient credit",
|
||||
"insufficient credits",
|
||||
"not enough credit",
|
||||
"not enough credits",
|
||||
"credit exhausted",
|
||||
"credits exhausted",
|
||||
"credit balance",
|
||||
"minimumcreditamountforusage",
|
||||
"minimum credit amount for usage",
|
||||
"minimum credit",
|
||||
"resource has been exhausted",
|
||||
}
|
||||
)
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
@@ -113,7 +143,7 @@ func initAntigravityTransport() {
|
||||
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
antigravityTransportOnce.Do(initAntigravityTransport)
|
||||
|
||||
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
client := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
// If no transport is set, use the shared HTTP/1.1 transport.
|
||||
if client.Transport == nil {
|
||||
client.Transport = antigravityTransport
|
||||
@@ -183,6 +213,231 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func injectEnabledCreditTypes(payload []byte) []byte {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return nil
|
||||
}
|
||||
updated, err := sjson.SetRawBytes(payload, "enabledCreditTypes", []byte(`["GOOGLE_ONE_AI"]`))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func classifyAntigravity429(body []byte) antigravity429Category {
|
||||
if len(body) == 0 {
|
||||
return antigravity429Unknown
|
||||
}
|
||||
lowerBody := strings.ToLower(string(body))
|
||||
for _, keyword := range antigravityQuotaExhaustedKeywords {
|
||||
if strings.Contains(lowerBody, keyword) {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
}
|
||||
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
|
||||
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
|
||||
return antigravity429Unknown
|
||||
}
|
||||
details := gjson.GetBytes(body, "error.details")
|
||||
if !details.Exists() || !details.IsArray() {
|
||||
return antigravity429Unknown
|
||||
}
|
||||
for _, detail := range details.Array() {
|
||||
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||
continue
|
||||
}
|
||||
reason := strings.TrimSpace(detail.Get("reason").String())
|
||||
if strings.EqualFold(reason, "QUOTA_EXHAUSTED") {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
if strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED") {
|
||||
return antigravity429RateLimited
|
||||
}
|
||||
}
|
||||
return antigravity429Unknown
|
||||
}
|
||||
|
||||
func antigravityCreditsRetryEnabled(cfg *config.Config) bool {
|
||||
return cfg != nil && cfg.QuotaExceeded.AntigravityCredits
|
||||
}
|
||||
|
||||
func antigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) bool {
|
||||
if auth == nil || strings.TrimSpace(auth.ID) == "" {
|
||||
return false
|
||||
}
|
||||
value, ok := antigravityCreditsExhaustedByAuth.Load(auth.ID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
until, ok := value.(time.Time)
|
||||
if !ok || until.IsZero() {
|
||||
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
|
||||
return false
|
||||
}
|
||||
if !until.After(now) {
|
||||
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func markAntigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) {
|
||||
if auth == nil || strings.TrimSpace(auth.ID) == "" {
|
||||
return
|
||||
}
|
||||
antigravityCreditsExhaustedByAuth.Store(auth.ID, now.Add(antigravityCreditsRetryTTL))
|
||||
}
|
||||
|
||||
func clearAntigravityCreditsExhausted(auth *cliproxyauth.Auth) {
|
||||
if auth == nil || strings.TrimSpace(auth.ID) == "" {
|
||||
return
|
||||
}
|
||||
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
|
||||
}
|
||||
|
||||
func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
authID := strings.TrimSpace(auth.ID)
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if authID == "" || modelName == "" {
|
||||
return ""
|
||||
}
|
||||
return authID + "|" + modelName
|
||||
}
|
||||
|
||||
func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool {
|
||||
key := antigravityPreferCreditsKey(auth, modelName)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
value, ok := antigravityPreferCreditsByModel.Load(key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
until, ok := value.(time.Time)
|
||||
if !ok || until.IsZero() {
|
||||
antigravityPreferCreditsByModel.Delete(key)
|
||||
return false
|
||||
}
|
||||
if !until.After(now) {
|
||||
antigravityPreferCreditsByModel.Delete(key)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) {
|
||||
key := antigravityPreferCreditsKey(auth, modelName)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
until := now.Add(antigravityCreditsRetryTTL)
|
||||
if retryAfter != nil && *retryAfter > 0 {
|
||||
until = now.Add(*retryAfter)
|
||||
}
|
||||
antigravityPreferCreditsByModel.Store(key, until)
|
||||
}
|
||||
|
||||
func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) {
|
||||
key := antigravityPreferCreditsKey(auth, modelName)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
antigravityPreferCreditsByModel.Delete(key)
|
||||
}
|
||||
|
||||
func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool {
|
||||
if reqErr != nil || statusCode == 0 {
|
||||
return false
|
||||
}
|
||||
if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout {
|
||||
return false
|
||||
}
|
||||
lowerBody := strings.ToLower(string(body))
|
||||
for _, keyword := range antigravityCreditsExhaustedKeywords {
|
||||
if strings.Contains(lowerBody, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newAntigravityStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) attemptCreditsFallback(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
httpClient *http.Client,
|
||||
token string,
|
||||
modelName string,
|
||||
payload []byte,
|
||||
stream bool,
|
||||
alt string,
|
||||
baseURL string,
|
||||
originalBody []byte,
|
||||
) (*http.Response, bool) {
|
||||
if !antigravityCreditsRetryEnabled(e.cfg) {
|
||||
return nil, false
|
||||
}
|
||||
if classifyAntigravity429(originalBody) != antigravity429QuotaExhausted {
|
||||
return nil, false
|
||||
}
|
||||
now := time.Now()
|
||||
if antigravityCreditsExhausted(auth, now) {
|
||||
return nil, false
|
||||
}
|
||||
creditsPayload := injectEnabledCreditTypes(payload)
|
||||
if len(creditsPayload) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL)
|
||||
if errReq != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errReq)
|
||||
return nil, true
|
||||
}
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, true
|
||||
}
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
retryAfter, _ := parseRetryDelay(originalBody)
|
||||
markAntigravityPreferCredits(auth, modelName, now, retryAfter)
|
||||
clearAntigravityCreditsExhausted(auth)
|
||||
return httpResp, true
|
||||
}
|
||||
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return nil, true
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||
clearAntigravityPreferCredits(auth, modelName)
|
||||
markAntigravityCreditsExhausted(auth, now)
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
@@ -203,8 +458,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
@@ -222,8 +477,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -237,7 +492,15 @@ attemptLoop:
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
|
||||
requestPayload := translated
|
||||
usedCreditsDirect := false
|
||||
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
|
||||
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
|
||||
requestPayload = creditsPayload
|
||||
usedCreditsDirect = true
|
||||
}
|
||||
}
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, false, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
@@ -245,7 +508,7 @@ attemptLoop:
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return resp, errDo
|
||||
}
|
||||
@@ -260,20 +523,50 @@ attemptLoop:
|
||||
return resp, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if usedCreditsDirect {
|
||||
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||
clearAntigravityPreferCredits(auth, baseModel)
|
||||
markAntigravityCreditsExhausted(auth, time.Now())
|
||||
}
|
||||
} else {
|
||||
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes)
|
||||
if creditsResp != nil {
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone())
|
||||
creditsBody, errCreditsRead := io.ReadAll(creditsResp.Body)
|
||||
if errClose := creditsResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close credits success response body error: %v", errClose)
|
||||
}
|
||||
if errCreditsRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead)
|
||||
err = errCreditsRead
|
||||
return resp, err
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody)
|
||||
reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()}
|
||||
reporter.EnsurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
@@ -295,33 +588,21 @@ attemptLoop:
|
||||
continue attemptLoop
|
||||
}
|
||||
}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
err = newAntigravityStatusErr(lastStatus, lastBody)
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
@@ -345,8 +626,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
@@ -364,8 +645,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -379,7 +660,15 @@ attemptLoop:
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
||||
requestPayload := translated
|
||||
usedCreditsDirect := false
|
||||
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
|
||||
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
|
||||
requestPayload = creditsPayload
|
||||
usedCreditsDirect = true
|
||||
}
|
||||
}
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
@@ -387,7 +676,7 @@ attemptLoop:
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return resp, errDo
|
||||
}
|
||||
@@ -401,14 +690,14 @@ attemptLoop:
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||
err = errRead
|
||||
return resp, err
|
||||
@@ -427,7 +716,24 @@ attemptLoop:
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if usedCreditsDirect {
|
||||
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||
clearAntigravityPreferCredits(auth, baseModel)
|
||||
markAntigravityCreditsExhausted(auth, time.Now())
|
||||
}
|
||||
} else {
|
||||
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
|
||||
if creditsResp != nil {
|
||||
httpResp = creditsResp
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
goto streamSuccessClaudeNonStream
|
||||
}
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
@@ -449,16 +755,11 @@ attemptLoop:
|
||||
continue attemptLoop
|
||||
}
|
||||
}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
streamSuccessClaudeNonStream:
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
@@ -471,29 +772,29 @@ attemptLoop:
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
line = helps.FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
payload := helps.JSONPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
|
||||
@@ -509,24 +810,18 @@ attemptLoop:
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
reporter.Publish(ctx, helps.ParseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
err = newAntigravityStatusErr(lastStatus, lastBody)
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
@@ -748,8 +1043,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
@@ -767,8 +1062,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -782,14 +1077,22 @@ attemptLoop:
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
||||
requestPayload := translated
|
||||
usedCreditsDirect := false
|
||||
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
|
||||
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
|
||||
requestPayload = creditsPayload
|
||||
usedCreditsDirect = true
|
||||
}
|
||||
}
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return nil, err
|
||||
}
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil, errDo
|
||||
}
|
||||
@@ -803,14 +1106,14 @@ attemptLoop:
|
||||
err = errDo
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||
err = errRead
|
||||
return nil, err
|
||||
@@ -829,7 +1132,24 @@ attemptLoop:
|
||||
err = errRead
|
||||
return nil, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if usedCreditsDirect {
|
||||
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||
clearAntigravityPreferCredits(auth, baseModel)
|
||||
markAntigravityCreditsExhausted(auth, time.Now())
|
||||
}
|
||||
} else {
|
||||
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
|
||||
if creditsResp != nil {
|
||||
httpResp = creditsResp
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
goto streamSuccessExecuteStream
|
||||
}
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
@@ -851,16 +1171,11 @@ attemptLoop:
|
||||
continue attemptLoop
|
||||
}
|
||||
}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
streamSuccessExecuteStream:
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
@@ -874,19 +1189,19 @@ attemptLoop:
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
line = helps.FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
payload := helps.JSONPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
@@ -899,11 +1214,11 @@ attemptLoop:
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
@@ -911,13 +1226,7 @@ attemptLoop:
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
err = newAntigravityStatusErr(lastStatus, lastBody)
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
@@ -1011,8 +1320,13 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -1026,7 +1340,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
@@ -1040,16 +1354,16 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
@@ -1305,6 +1619,11 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -1316,7 +1635,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = []byte(payloadStr)
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -1479,7 +1798,7 @@ func antigravityWait(ctx context.Context, wait time.Duration) error {
|
||||
}
|
||||
}
|
||||
|
||||
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||
var antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
|
||||
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
||||
return []string{base}
|
||||
}
|
||||
|
||||
423
internal/runtime/executor/antigravity_executor_credits_test.go
Normal file
423
internal/runtime/executor/antigravity_executor_credits_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
func resetAntigravityCreditsRetryState() {
|
||||
antigravityCreditsExhaustedByAuth = sync.Map{}
|
||||
antigravityPreferCreditsByModel = sync.Map{}
|
||||
}
|
||||
|
||||
func TestClassifyAntigravity429(t *testing.T) {
|
||||
t.Run("quota exhausted", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
|
||||
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured rate limit", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
if got := classifyAntigravity429(body); got != antigravity429RateLimited {
|
||||
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured quota exhausted", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "QUOTA_EXHAUSTED"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
|
||||
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"too many requests"}}`)
|
||||
if got := classifyAntigravity429(body); got != antigravity429Unknown {
|
||||
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429Unknown)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInjectEnabledCreditTypes(t *testing.T) {
|
||||
body := []byte(`{"model":"gemini-2.5-flash","request":{}}`)
|
||||
got := injectEnabledCreditTypes(body)
|
||||
if got == nil {
|
||||
t.Fatal("injectEnabledCreditTypes() returned nil")
|
||||
}
|
||||
if !strings.Contains(string(got), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("injectEnabledCreditTypes() = %s, want enabledCreditTypes", string(got))
|
||||
}
|
||||
|
||||
if got := injectEnabledCreditTypes([]byte(`not json`)); got != nil {
|
||||
t.Fatalf("injectEnabledCreditTypes() for invalid json = %s, want nil", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) {
|
||||
for _, body := range [][]byte{
|
||||
[]byte(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`),
|
||||
[]byte(`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`),
|
||||
[]byte(`{"error":{"message":"Resource has been exhausted"}}`),
|
||||
} {
|
||||
if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, body, nil) {
|
||||
t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body))
|
||||
}
|
||||
}
|
||||
if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) {
|
||||
t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) {
|
||||
resetAntigravityCreditsRetryState()
|
||||
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
requestBodies []string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
|
||||
mu.Lock()
|
||||
requestBodies = append(requestBodies, string(body))
|
||||
reqNum := len(requestBodies)
|
||||
mu.Unlock()
|
||||
|
||||
if reqNum == 1 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("second request body missing enabledCreditTypes: %s", string(body))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewAntigravityExecutor(&config.Config{
|
||||
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-credits-ok",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "token",
|
||||
"project_id": "project-1",
|
||||
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gemini-2.5-flash",
|
||||
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatAntigravity,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatal("Execute() returned empty payload")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(requestBodies) != 2 {
|
||||
t.Fatalf("request count = %d, want 2", len(requestBodies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) {
|
||||
resetAntigravityCreditsRetryState()
|
||||
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||
|
||||
var requestCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewAntigravityExecutor(&config.Config{
|
||||
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-credits-exhausted",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "token",
|
||||
"project_id": "project-1",
|
||||
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
markAntigravityCreditsExhausted(auth, time.Now())
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gemini-2.5-flash",
|
||||
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatAntigravity,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Execute() error = nil, want 429")
|
||||
}
|
||||
sErr, ok := err.(statusErr)
|
||||
if !ok {
|
||||
t.Fatalf("Execute() error type = %T, want statusErr", err)
|
||||
}
|
||||
if got := sErr.StatusCode(); got != http.StatusTooManyRequests {
|
||||
t.Fatalf("Execute() status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||
}
|
||||
if requestCount != 1 {
|
||||
t.Fatalf("request count = %d, want 1", requestCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) {
|
||||
resetAntigravityCreditsRetryState()
|
||||
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
requestBodies []string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
|
||||
mu.Lock()
|
||||
requestBodies = append(requestBodies, string(body))
|
||||
reqNum := len(requestBodies)
|
||||
mu.Unlock()
|
||||
|
||||
switch reqNum {
|
||||
case 1:
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`))
|
||||
case 2, 3:
|
||||
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("request %d body missing enabledCreditTypes: %s", reqNum, string(body))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||
default:
|
||||
t.Fatalf("unexpected request count %d", reqNum)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewAntigravityExecutor(&config.Config{
|
||||
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-prefer-credits",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "token",
|
||||
"project_id": "project-1",
|
||||
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
request := cliproxyexecutor.Request{
|
||||
Model: "gemini-2.5-flash",
|
||||
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity}
|
||||
|
||||
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
|
||||
t.Fatalf("first Execute() error = %v", err)
|
||||
}
|
||||
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
|
||||
t.Fatalf("second Execute() error = %v", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(requestBodies) != 3 {
|
||||
t.Fatalf("request count = %d, want 3", len(requestBodies))
|
||||
}
|
||||
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0])
|
||||
}
|
||||
if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("fallback request missing credits: %s", requestBodies[1])
|
||||
}
|
||||
if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("preferred request missing credits: %s", requestBodies[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) {
|
||||
resetAntigravityCreditsRetryState()
|
||||
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
firstCount int
|
||||
secondCount int
|
||||
)
|
||||
|
||||
firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
|
||||
mu.Lock()
|
||||
firstCount++
|
||||
reqNum := firstCount
|
||||
mu.Unlock()
|
||||
|
||||
switch reqNum {
|
||||
case 1:
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`))
|
||||
case 2:
|
||||
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("credits retry missing enabledCreditTypes: %s", string(body))
|
||||
}
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`))
|
||||
default:
|
||||
t.Fatalf("unexpected first server request count %d", reqNum)
|
||||
}
|
||||
}))
|
||||
defer firstServer.Close()
|
||||
|
||||
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
secondCount++
|
||||
mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||
}))
|
||||
defer secondServer.Close()
|
||||
|
||||
exec := NewAntigravityExecutor(&config.Config{
|
||||
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseurl-fallback",
|
||||
Attributes: map[string]string{
|
||||
"base_url": firstServer.URL,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "token",
|
||||
"project_id": "project-1",
|
||||
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
originalOrder := antigravityBaseURLFallbackOrder
|
||||
defer func() { antigravityBaseURLFallbackOrder = originalOrder }()
|
||||
antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
|
||||
return []string{firstServer.URL, secondServer.URL}
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gemini-2.5-flash",
|
||||
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatAntigravity,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatal("Execute() returned empty payload")
|
||||
}
|
||||
if firstCount != 2 {
|
||||
t.Fatalf("first server request count = %d, want 2", firstCount)
|
||||
}
|
||||
if secondCount != 1 {
|
||||
t.Fatalf("second server request count = %d, want 1", secondCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) {
|
||||
resetAntigravityCreditsRetryState()
|
||||
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||
|
||||
var requestBodies []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
requestBodies = append(requestBodies, string(body))
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewAntigravityExecutor(&config.Config{
|
||||
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-flag-disabled",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "token",
|
||||
"project_id": "project-1",
|
||||
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil)
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gemini-2.5-flash",
|
||||
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatAntigravity,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Execute() error = nil, want 429")
|
||||
}
|
||||
if len(requestBodies) != 1 {
|
||||
t.Fatalf("request count = %d, want 1", len(requestBodies))
|
||||
}
|
||||
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||
t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0])
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -18,10 +17,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -44,6 +46,10 @@ type ClaudeExecutor struct {
|
||||
// Previously "proxy_" was used but this is a detectable fingerprint difference.
|
||||
const claudeToolPrefix = ""
|
||||
|
||||
// Anthropic-compatible upstreams may reject or even crash when Claude models
|
||||
// omit max_tokens. Prefer registered model metadata before using a fallback.
|
||||
const defaultModelMaxTokens = 1024
|
||||
|
||||
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||
|
||||
func (e *ClaudeExecutor) Identifier() string { return "claude" }
|
||||
@@ -86,7 +92,7 @@ func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -101,8 +107,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
@@ -125,8 +131,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body = ensureModelMaxTokens(body, baseModel)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -153,6 +160,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
if experimentalCCHSigningEnabled(e.cfg, auth) {
|
||||
bodyForUpstream = signAnthropicMessagesBody(bodyForUpstream)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
|
||||
@@ -166,7 +176,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -178,33 +188,33 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
helps.LogWithRequestID(ctx).Warn(msg)
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
helps.LogWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
@@ -213,7 +223,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
@@ -226,19 +236,19 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}()
|
||||
data, err := io.ReadAll(decodedBody)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if stream {
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
for _, line := range lines {
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
reporter.publish(ctx, parseClaudeUsage(data))
|
||||
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
@@ -269,8 +279,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
originalPayloadSource := req.Payload
|
||||
@@ -291,8 +301,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body = ensureModelMaxTokens(body, baseModel)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -316,6 +327,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
if experimentalCCHSigningEnabled(e.cfg, auth) {
|
||||
bodyForUpstream = signAnthropicMessagesBody(bodyForUpstream)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
|
||||
@@ -329,7 +343,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -341,33 +355,33 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
helps.LogWithRequestID(ctx).Warn(msg)
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
helps.LogWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
@@ -376,7 +390,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
@@ -397,9 +411,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
@@ -411,8 +425,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
return
|
||||
@@ -424,9 +438,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
@@ -446,8 +460,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
@@ -496,7 +510,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -508,32 +522,32 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewUtlsHTTPClient(e.cfg, auth, 0)
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
helps.LogWithRequestID(ctx).Warn(msg)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
helps.LogWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
@@ -541,7 +555,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
decodedBody, err := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
@@ -554,10 +568,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}()
|
||||
data, err := io.ReadAll(decodedBody)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
|
||||
@@ -793,13 +807,13 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
|
||||
var deviceProfile claudeDeviceProfile
|
||||
stabilizeDeviceProfile := helps.ClaudeDeviceProfileStabilizationEnabled(cfg)
|
||||
var deviceProfile helps.ClaudeDeviceProfile
|
||||
if stabilizeDeviceProfile {
|
||||
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
|
||||
deviceProfile = helps.ResolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
|
||||
}
|
||||
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,structured-outputs-2025-12-15,fast-mode-2026-02-01,redact-thinking-2026-02-12,token-efficient-tools-2026-03-28"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
baseBetas = val
|
||||
if !strings.Contains(val, "oauth") {
|
||||
@@ -837,13 +851,22 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
// Only set browser access header for API key mode; real Claude Code CLI does not send it.
|
||||
if useAPIKey {
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
}
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
// Session ID: stable per auth/apiKey, matches Claude Code's X-Claude-Code-Session-Id header.
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Claude-Code-Session-Id", helps.CachedSessionID(apiKey))
|
||||
// Per-request UUID, matches Claude Code's x-client-request-id for first-party API.
|
||||
if isAnthropicBase {
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "x-client-request-id", uuid.New().String())
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
@@ -858,16 +881,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
|
||||
// to the configured baseline while still allowing newer official
|
||||
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
|
||||
if stabilizeDeviceProfile {
|
||||
helps.ApplyClaudeDeviceProfileHeaders(r, deviceProfile)
|
||||
} else {
|
||||
helps.ApplyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
if stabilizeDeviceProfile {
|
||||
applyClaudeDeviceProfileHeaders(r, deviceProfile)
|
||||
} else {
|
||||
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
|
||||
}
|
||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||
@@ -893,7 +916,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
|
||||
func checkSystemInstructions(payload []byte) []byte {
|
||||
return checkSystemInstructionsWithMode(payload, false)
|
||||
return checkSystemInstructionsWithSigningMode(payload, false, false, "2.1.63", "", "")
|
||||
}
|
||||
|
||||
func isClaudeOAuthToken(apiKey string) bool {
|
||||
@@ -1037,7 +1060,7 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
if prefix == "" {
|
||||
return line
|
||||
}
|
||||
payload := jsonPayload(line)
|
||||
payload := helps.JSONPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return line
|
||||
}
|
||||
@@ -1088,6 +1111,38 @@ func getClientUserAgent(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseEntrypointFromUA extracts the entrypoint from a Claude Code User-Agent.
|
||||
// Format: "claude-cli/x.y.z (external, cli)" → "cli"
|
||||
// Format: "claude-cli/x.y.z (external, vscode)" → "vscode"
|
||||
// Returns "cli" if parsing fails or UA is not Claude Code.
|
||||
func parseEntrypointFromUA(userAgent string) string {
|
||||
// Find content inside parentheses
|
||||
start := strings.Index(userAgent, "(")
|
||||
end := strings.LastIndex(userAgent, ")")
|
||||
if start < 0 || end <= start {
|
||||
return "cli"
|
||||
}
|
||||
inner := userAgent[start+1 : end]
|
||||
// Split by comma, take the second part (entrypoint is at index 1, after USER_TYPE)
|
||||
// Format: "(USER_TYPE, ENTRYPOINT[, extra...])"
|
||||
parts := strings.Split(inner, ",")
|
||||
if len(parts) >= 2 {
|
||||
ep := strings.TrimSpace(parts[1])
|
||||
if ep != "" {
|
||||
return ep
|
||||
}
|
||||
}
|
||||
return "cli"
|
||||
}
|
||||
|
||||
// getWorkloadFromContext extracts workload identifier from the gin request headers.
|
||||
func getWorkloadFromContext(ctx context.Context) string {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
return strings.TrimSpace(ginCtx.GetHeader("X-CPA-Claude-Workload"))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
|
||||
@@ -1115,43 +1170,14 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bo
|
||||
return cloakMode, strictMode, sensitiveWords, cacheUserID
|
||||
}
|
||||
|
||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
|
||||
if cfg == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey, baseURL := claudeCreds(auth)
|
||||
if apiKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
|
||||
// Match by API key
|
||||
if strings.EqualFold(cfgKey, apiKey) {
|
||||
// If baseURL is specified, also check it
|
||||
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
|
||||
continue
|
||||
}
|
||||
return entry.Cloak
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||
// When useCache is false, a new user ID is generated for every call.
|
||||
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
generateID := func() string {
|
||||
if useCache {
|
||||
return cachedUserID(apiKey)
|
||||
return helps.CachedUserID(apiKey)
|
||||
}
|
||||
return generateFakeUserID()
|
||||
return helps.GenerateFakeUserID()
|
||||
}
|
||||
|
||||
metadata := gjson.GetBytes(payload, "metadata")
|
||||
@@ -1161,38 +1187,84 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
}
|
||||
|
||||
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||
if existingUserID == "" || !helps.IsValidUserID(existingUserID) {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// generateBillingHeader creates the x-anthropic-billing-header text block that
|
||||
// real Claude Code prepends to every system prompt array.
|
||||
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
|
||||
func generateBillingHeader(payload []byte) string {
|
||||
// Generate a deterministic cch hash from the payload content (system + messages + tools).
|
||||
// Real Claude Code uses a 5-char hex hash that varies per request.
|
||||
h := sha256.Sum256(payload)
|
||||
cch := hex.EncodeToString(h[:])[:5]
|
||||
// fingerprintSalt is the salt used by Claude Code to compute the 3-char build fingerprint.
|
||||
const fingerprintSalt = "59cf53e54c78"
|
||||
|
||||
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
|
||||
buildBytes := make([]byte, 2)
|
||||
_, _ = rand.Read(buildBytes)
|
||||
buildHash := hex.EncodeToString(buildBytes)[:3]
|
||||
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
|
||||
// computeFingerprint computes the 3-char build fingerprint that Claude Code embeds in cc_version.
|
||||
// Algorithm: SHA256(salt + messageText[4] + messageText[7] + messageText[20] + version)[:3]
|
||||
func computeFingerprint(messageText, version string) string {
|
||||
indices := [3]int{4, 7, 20}
|
||||
runes := []rune(messageText)
|
||||
var sb strings.Builder
|
||||
for _, idx := range indices {
|
||||
if idx < len(runes) {
|
||||
sb.WriteRune(runes[idx])
|
||||
} else {
|
||||
sb.WriteRune('0')
|
||||
}
|
||||
}
|
||||
input := fingerprintSalt + sb.String() + version
|
||||
h := sha256.Sum256([]byte(input))
|
||||
return hex.EncodeToString(h[:])[:3]
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
|
||||
// generateBillingHeader creates the x-anthropic-billing-header text block that
|
||||
// real Claude Code prepends to every system prompt array.
|
||||
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=<ep>; cch=<hash>; [cc_workload=<wl>;]
|
||||
func generateBillingHeader(payload []byte, experimentalCCHSigning bool, version, messageText, entrypoint, workload string) string {
|
||||
if entrypoint == "" {
|
||||
entrypoint = "cli"
|
||||
}
|
||||
buildHash := computeFingerprint(messageText, version)
|
||||
workloadPart := ""
|
||||
if workload != "" {
|
||||
workloadPart = fmt.Sprintf(" cc_workload=%s;", workload)
|
||||
}
|
||||
|
||||
if experimentalCCHSigning {
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=00000;%s", version, buildHash, entrypoint, workloadPart)
|
||||
}
|
||||
|
||||
// Generate a deterministic cch hash from the payload content (system + messages + tools).
|
||||
h := sha256.Sum256(payload)
|
||||
cch := hex.EncodeToString(h[:])[:5]
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=%s; cch=%s;%s", version, buildHash, entrypoint, cch, workloadPart)
|
||||
}
|
||||
|
||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
return checkSystemInstructionsWithSigningMode(payload, strictMode, false, "2.1.63", "", "")
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithSigningMode injects Claude Code-style system blocks:
|
||||
//
|
||||
// system[0]: billing header (no cache_control)
|
||||
// system[1]: agent identifier (no cache_control)
|
||||
// system[2..]: user system messages (cache_control added when missing)
|
||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
func checkSystemInstructionsWithSigningMode(payload []byte, strictMode bool, experimentalCCHSigning bool, version, entrypoint, workload string) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
|
||||
billingText := generateBillingHeader(payload)
|
||||
// Extract original message text for fingerprint computation (before billing injection).
|
||||
// Use the first system text block's content as the fingerprint source.
|
||||
messageText := ""
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
messageText = part.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if system.Type == gjson.String {
|
||||
messageText = system.String()
|
||||
}
|
||||
|
||||
billingText := generateBillingHeader(payload, experimentalCCHSigning, version, messageText, entrypoint, workload)
|
||||
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
|
||||
// No cache_control on the agent block. It is a cloaking artifact with zero cache
|
||||
// value (the last system block is what actually triggers caching of all system content).
|
||||
@@ -1247,51 +1319,44 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
|
||||
clientUserAgent := getClientUserAgent(ctx)
|
||||
useExperimentalCCHSigning := experimentalCCHSigningEnabled(cfg, auth)
|
||||
|
||||
// Get cloak config from ClaudeKey configuration
|
||||
cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)
|
||||
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
|
||||
|
||||
// Determine cloak settings
|
||||
var cloakMode string
|
||||
var strictMode bool
|
||||
var sensitiveWords []string
|
||||
var cacheUserID bool
|
||||
cloakMode := attrMode
|
||||
strictMode := attrStrict
|
||||
sensitiveWords := attrWords
|
||||
cacheUserID := attrCache
|
||||
|
||||
if cloakCfg != nil {
|
||||
cloakMode = cloakCfg.Mode
|
||||
strictMode = cloakCfg.StrictMode
|
||||
sensitiveWords = cloakCfg.SensitiveWords
|
||||
if mode := strings.TrimSpace(cloakCfg.Mode); mode != "" {
|
||||
cloakMode = mode
|
||||
}
|
||||
if cloakCfg.StrictMode {
|
||||
strictMode = true
|
||||
}
|
||||
if len(cloakCfg.SensitiveWords) > 0 {
|
||||
sensitiveWords = cloakCfg.SensitiveWords
|
||||
}
|
||||
if cloakCfg.CacheUserID != nil {
|
||||
cacheUserID = *cloakCfg.CacheUserID
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to auth attributes if no config found
|
||||
if cloakMode == "" {
|
||||
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
|
||||
cloakMode = attrMode
|
||||
if !strictMode {
|
||||
strictMode = attrStrict
|
||||
}
|
||||
if len(sensitiveWords) == 0 {
|
||||
sensitiveWords = attrWords
|
||||
}
|
||||
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||
cacheUserID = attrCache
|
||||
}
|
||||
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||
_, _, _, attrCache := getCloakConfigFromAuth(auth)
|
||||
cacheUserID = attrCache
|
||||
}
|
||||
|
||||
// Determine if cloaking should be applied
|
||||
if !shouldCloak(cloakMode, clientUserAgent) {
|
||||
if !helps.ShouldCloak(cloakMode, clientUserAgent) {
|
||||
return payload
|
||||
}
|
||||
|
||||
// Skip system instructions for claude-3-5-haiku models
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
payload = checkSystemInstructionsWithMode(payload, strictMode)
|
||||
billingVersion := helps.DefaultClaudeVersion(cfg)
|
||||
entrypoint := parseEntrypointFromUA(clientUserAgent)
|
||||
workload := getWorkloadFromContext(ctx)
|
||||
payload = checkSystemInstructionsWithSigningMode(payload, strictMode, useExperimentalCCHSigning, billingVersion, entrypoint, workload)
|
||||
}
|
||||
|
||||
// Inject fake user ID
|
||||
@@ -1299,8 +1364,8 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
|
||||
// Apply sensitive word obfuscation
|
||||
if len(sensitiveWords) > 0 {
|
||||
matcher := buildSensitiveWordMatcher(sensitiveWords)
|
||||
payload = obfuscateSensitiveWords(payload, matcher)
|
||||
matcher := helps.BuildSensitiveWordMatcher(sensitiveWords)
|
||||
payload = helps.ObfuscateSensitiveWords(payload, matcher)
|
||||
}
|
||||
|
||||
return payload
|
||||
@@ -1310,7 +1375,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
// According to Anthropic's documentation, cache prefixes are created in order: tools -> system -> messages.
|
||||
// This function adds cache_control to:
|
||||
// 1. The LAST tool in the tools array (caches all tool definitions)
|
||||
// 2. The LAST element in the system array (caches system prompt)
|
||||
// 2. The LAST system prompt element
|
||||
// 3. The SECOND-TO-LAST user turn (caches conversation history for multi-turn)
|
||||
//
|
||||
// Up to 4 cache breakpoints are allowed per request. Tools, System, and Messages are INDEPENDENT breakpoints.
|
||||
@@ -1880,3 +1945,26 @@ func injectSystemCacheControl(payload []byte) []byte {
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
func ensureModelMaxTokens(body []byte, modelID string) []byte {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return body
|
||||
}
|
||||
|
||||
if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) {
|
||||
if strings.EqualFold(provider, "claude") {
|
||||
maxTokens := defaultModelMaxTokens
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(strings.TrimSpace(modelID), "claude"); info != nil && info.MaxCompletionTokens > 0 {
|
||||
maxTokens = info.MaxCompletionTokens
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", maxTokens)
|
||||
return body
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -14,7 +16,10 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
xxHash64 "github.com/pierrec/xxHash/xxHash64"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -23,9 +28,7 @@ import (
|
||||
)
|
||||
|
||||
func resetClaudeDeviceProfileCache() {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
helps.ResetClaudeDeviceProfileCache()
|
||||
}
|
||||
|
||||
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
|
||||
@@ -98,7 +101,7 @@ func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
|
||||
req := newClaudeHeaderTestRequest(t, incoming)
|
||||
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
assertClaudeFingerprint(t, req.Header, "evil-client/9.9", "9.9.9", "v24.5.0", "Linux", "x64")
|
||||
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
||||
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
||||
}
|
||||
@@ -338,7 +341,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
||||
var pauseOnce sync.Once
|
||||
var releaseOnce sync.Once
|
||||
|
||||
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
|
||||
helps.ClaudeDeviceProfileBeforeCandidateStore = func(candidate helps.ClaudeDeviceProfile) {
|
||||
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
|
||||
return
|
||||
}
|
||||
@@ -346,13 +349,13 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
||||
<-releaseLow
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
claudeDeviceProfileBeforeCandidateStore = nil
|
||||
helps.ClaudeDeviceProfileBeforeCandidateStore = nil
|
||||
releaseOnce.Do(func() { close(releaseLow) })
|
||||
})
|
||||
|
||||
lowResultCh := make(chan claudeDeviceProfile, 1)
|
||||
lowResultCh := make(chan helps.ClaudeDeviceProfile, 1)
|
||||
go func() {
|
||||
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
lowResultCh <- helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
@@ -367,7 +370,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
||||
t.Fatal("timed out waiting for lower candidate to pause before storing")
|
||||
}
|
||||
|
||||
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
highResult := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||
@@ -398,7 +401,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
||||
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
|
||||
}
|
||||
|
||||
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
cached := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
}, cfg)
|
||||
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
@@ -564,7 +567,7 @@ func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *tes
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
|
||||
@@ -591,14 +594,14 @@ func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallbac
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
|
||||
}
|
||||
|
||||
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
|
||||
if claudeDeviceProfileStabilizationEnabled(nil) {
|
||||
if helps.ClaudeDeviceProfileStabilizationEnabled(nil) {
|
||||
t.Fatal("expected nil config to default to disabled stabilization")
|
||||
}
|
||||
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
||||
if helps.ClaudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
||||
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
|
||||
}
|
||||
}
|
||||
@@ -796,8 +799,6 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
var requestModels []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -857,15 +858,13 @@ func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||
if userIDs[0] != userIDs[1] {
|
||||
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) {
|
||||
if !helps.IsValidUserID(userIDs[0]) {
|
||||
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||
}
|
||||
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
@@ -903,7 +902,7 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
if userIDs[0] == userIDs[1] {
|
||||
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
||||
if !helps.IsValidUserID(userIDs[0]) || !helps.IsValidUserID(userIDs[1]) {
|
||||
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
}
|
||||
@@ -1183,6 +1182,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "test-claude-max-completion-tokens-client"
|
||||
modelID := "test-claude-max-completion-tokens-model"
|
||||
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||
ID: modelID,
|
||||
Type: "claude",
|
||||
OwnedBy: "anthropic",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
MaxCompletionTokens: 4096,
|
||||
UserDefined: true,
|
||||
}})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||
out := ensureModelMaxTokens(input, modelID)
|
||||
|
||||
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
|
||||
t.Fatalf("max_tokens = %d, want %d", got, 4096)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "test-claude-default-max-tokens-client"
|
||||
modelID := "test-claude-default-max-tokens-model"
|
||||
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||
ID: modelID,
|
||||
Type: "claude",
|
||||
OwnedBy: "anthropic",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
UserDefined: true,
|
||||
}})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||
out := ensureModelMaxTokens(input, modelID)
|
||||
|
||||
if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
|
||||
t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "test-claude-preserve-max-tokens-client"
|
||||
modelID := "test-claude-preserve-max-tokens-model"
|
||||
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||
ID: modelID,
|
||||
Type: "claude",
|
||||
OwnedBy: "anthropic",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
MaxCompletionTokens: 4096,
|
||||
UserDefined: true,
|
||||
}})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
|
||||
out := ensureModelMaxTokens(input, modelID)
|
||||
|
||||
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
|
||||
t.Fatalf("max_tokens = %d, want %d", got, 2048)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
|
||||
input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||
out := ensureModelMaxTokens(input, "test-claude-unregistered-model")
|
||||
|
||||
if gjson.GetBytes(out, "max_tokens").Exists() {
|
||||
t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
||||
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
||||
// compressed SSE body that would silently break the line scanner.
|
||||
@@ -1340,6 +1416,35 @@ func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
||||
// detects zstd-compressed content via magic bytes even when Content-Encoding is absent.
|
||||
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewWriter: %v", err)
|
||||
}
|
||||
_, _ = enc.Write([]byte(plaintext))
|
||||
_ = enc.Close()
|
||||
|
||||
rc := io.NopCloser(&buf)
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
|
||||
// plain text untouched when Content-Encoding is absent and no magic bytes match.
|
||||
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
|
||||
@@ -1411,77 +1516,6 @@ func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
|
||||
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
|
||||
// path's enforced identity encoding.
|
||||
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
||||
var gotEncoding string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotEncoding != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
||||
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
|
||||
// Content-Encoding is absent.
|
||||
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewWriter: %v", err)
|
||||
}
|
||||
_, _ = enc.Write([]byte(plaintext))
|
||||
_ = enc.Close()
|
||||
|
||||
rc := io.NopCloser(&buf)
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
|
||||
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
|
||||
// the Content-Encoding header. This closes the gap left by PR #1771, which only
|
||||
@@ -1565,6 +1599,45 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies that the
|
||||
// streaming executor enforces Accept-Encoding: identity regardless of auth.Attributes override.
|
||||
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
||||
var gotEncoding string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotEncoding != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 1: String system prompt is preserved and converted to a content block
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
@@ -1648,3 +1721,115 @@ func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
||||
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExperimentalCCHSigningDisabledByDefaultKeepsLegacyHeader(t *testing.T) {
|
||||
var seenBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
seenBody = bytes.Clone(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if len(seenBody) == 0 {
|
||||
t.Fatal("expected request body to be captured")
|
||||
}
|
||||
|
||||
billingHeader := gjson.GetBytes(seenBody, "system.0.text").String()
|
||||
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
|
||||
t.Fatalf("system.0.text = %q, want billing header", billingHeader)
|
||||
}
|
||||
if strings.Contains(billingHeader, "cch=00000;") {
|
||||
t.Fatalf("legacy mode should not forward cch placeholder, got %q", billingHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExperimentalCCHSigningOptInSignsFinalBody(t *testing.T) {
|
||||
var seenBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
seenBody = bytes.Clone(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{{
|
||||
APIKey: "key-123",
|
||||
BaseURL: server.URL,
|
||||
ExperimentalCCHSigning: true,
|
||||
}},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
const messageText = "please keep literal cch=00000 in this message"
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"please keep literal cch=00000 in this message"}]}]}`)
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if len(seenBody) == 0 {
|
||||
t.Fatal("expected request body to be captured")
|
||||
}
|
||||
if got := gjson.GetBytes(seenBody, "messages.0.content.0.text").String(); got != messageText {
|
||||
t.Fatalf("message text = %q, want %q", got, messageText)
|
||||
}
|
||||
|
||||
billingPattern := regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)([0-9a-f]{5})(;)`)
|
||||
match := billingPattern.FindSubmatch(seenBody)
|
||||
if match == nil {
|
||||
t.Fatalf("expected signed billing header in body: %s", string(seenBody))
|
||||
}
|
||||
actualCCH := string(match[2])
|
||||
unsignedBody := billingPattern.ReplaceAll(seenBody, []byte(`${1}00000${3}`))
|
||||
wantCCH := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, 0x6E52736AC806831E)&0xFFFFF)
|
||||
if actualCCH != wantCCH {
|
||||
t.Fatalf("cch = %q, want %q\nbody: %s", actualCCH, wantCCH, string(seenBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmitted(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{{
|
||||
APIKey: "key-123",
|
||||
Cloak: &config.CloakConfig{
|
||||
StrictMode: true,
|
||||
SensitiveWords: []string{"proxy"},
|
||||
},
|
||||
}},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "key-123"}}
|
||||
payload := []byte(`{"system":"proxy rules","messages":[{"role":"user","content":[{"type":"text","text":"proxy access"}]}]}`)
|
||||
|
||||
out := applyCloaking(context.Background(), cfg, auth, payload, "claude-3-5-sonnet-20241022", "key-123")
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("expected strict mode to keep only injected system blocks, got %d", len(blocks))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); !strings.Contains(got, "\u200B") {
|
||||
t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
81
internal/runtime/executor/claude_signing.go
Normal file
81
internal/runtime/executor/claude_signing.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
xxHash64 "github.com/pierrec/xxHash/xxHash64"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const claudeCCHSeed uint64 = 0x6E52736AC806831E
|
||||
|
||||
var claudeBillingHeaderCCHPattern = regexp.MustCompile(`\bcch=([0-9a-f]{5});`)
|
||||
|
||||
func signAnthropicMessagesBody(body []byte) []byte {
|
||||
billingHeader := gjson.GetBytes(body, "system.0.text").String()
|
||||
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
|
||||
return body
|
||||
}
|
||||
if !claudeBillingHeaderCCHPattern.MatchString(billingHeader) {
|
||||
return body
|
||||
}
|
||||
|
||||
unsignedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(billingHeader, "cch=00000;")
|
||||
unsignedBody, err := sjson.SetBytes(body, "system.0.text", unsignedBillingHeader)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
cch := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, claudeCCHSeed)&0xFFFFF)
|
||||
signedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(unsignedBillingHeader, "cch="+cch+";")
|
||||
signedBody, err := sjson.SetBytes(unsignedBody, "system.0.text", signedBillingHeader)
|
||||
if err != nil {
|
||||
return unsignedBody
|
||||
}
|
||||
return signedBody
|
||||
}
|
||||
|
||||
func resolveClaudeKeyConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.ClaudeKey {
|
||||
if cfg == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey, baseURL := claudeCreds(auth)
|
||||
if apiKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if !strings.EqualFold(cfgKey, apiKey) {
|
||||
continue
|
||||
}
|
||||
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
|
||||
continue
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
|
||||
entry := resolveClaudeKeyConfig(cfg, auth)
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return entry.Cloak
|
||||
}
|
||||
|
||||
func experimentalCCHSigningEnabled(cfg *config.Config, auth *cliproxyauth.Auth) bool {
|
||||
entry := resolveClaudeKeyConfig(cfg, auth)
|
||||
return entry != nil && entry.ExperimentalCCHSigning
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -28,8 +29,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexOriginator = "codex_cli_rs"
|
||||
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
|
||||
codexOriginator = "codex-tui"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
@@ -73,7 +74,7 @@ func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -88,8 +89,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
@@ -106,16 +107,15 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body = normalizeCodexInstructions(body)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -129,7 +129,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -140,10 +140,10 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -151,20 +151,20 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
for _, line := range lines {
|
||||
@@ -177,8 +177,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseCodexUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseCodexUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
|
||||
var param any
|
||||
@@ -198,8 +198,8 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai-response")
|
||||
@@ -216,10 +216,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.DeleteBytes(body, "stream")
|
||||
body = normalizeCodexInstructions(body)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -233,7 +234,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -244,10 +245,10 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -255,22 +256,22 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
reporter.ensurePublished(ctx)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||
reporter.EnsurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
@@ -288,8 +289,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
@@ -306,15 +307,14 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
body = normalizeCodexInstructions(body)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -328,7 +328,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -340,24 +340,24 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, readErr := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||
}
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
|
||||
return nil, readErr
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||
return nil, err
|
||||
}
|
||||
@@ -374,13 +374,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if gjson.GetBytes(data, "type").String() == "response.completed" {
|
||||
if detail, ok := parseCodexUsage(data); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseCodexUsage(data); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -391,8 +391,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
@@ -415,10 +415,9 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
body = normalizeCodexInstructions(body)
|
||||
|
||||
enc, err := tokenizerForCodexModel(baseModel)
|
||||
if err != nil {
|
||||
@@ -597,18 +596,18 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
|
||||
var cache codexCache
|
||||
var cache helps.CodexCache
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||
var ok bool
|
||||
if cache, ok = getCodexCache(key); !ok {
|
||||
cache = codexCache{
|
||||
if cache, ok = helps.GetCodexCache(key); !ok {
|
||||
cache = helps.CodexCache{
|
||||
ID: uuid.New().String(),
|
||||
Expire: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
setCodexCache(key, cache)
|
||||
helps.SetCodexCache(key, cache)
|
||||
}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
@@ -617,7 +616,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
cache.ID = promptCacheKey.String()
|
||||
}
|
||||
} else if from == "openai" {
|
||||
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
||||
if apiKey := strings.TrimSpace(helps.APIKeyFromContext(ctx)); apiKey != "" {
|
||||
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||
}
|
||||
}
|
||||
@@ -630,7 +629,6 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
return nil, err
|
||||
}
|
||||
if cache.ID != "" {
|
||||
httpReq.Header.Set("Conversation_id", cache.ID)
|
||||
httpReq.Header.Set("Session_id", cache.ID)
|
||||
}
|
||||
return httpReq, nil
|
||||
@@ -645,13 +643,19 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
if ginHeaders.Get("X-Codex-Beta-Features") != "" {
|
||||
r.Header.Set("X-Codex-Beta-Features", ginHeaders.Get("X-Codex-Beta-Features"))
|
||||
}
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
if strings.Contains(r.Header.Get("User-Agent"), "Mac OS") {
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
}
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
@@ -685,13 +689,47 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
errCode := statusCode
|
||||
if isCodexModelCapacityError(body) {
|
||||
errCode = http.StatusTooManyRequests
|
||||
}
|
||||
err := statusErr{code: errCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeCodexInstructions(body []byte) []byte {
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() || instructions.Type == gjson.Null {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func isCodexModelCapacityError(errorBody []byte) bool {
|
||||
if len(errorBody) == 0 {
|
||||
return false
|
||||
}
|
||||
candidates := []string{
|
||||
gjson.GetBytes(errorBody, "error.message").String(),
|
||||
gjson.GetBytes(errorBody, "message").String(),
|
||||
string(errorBody),
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
lower := strings.ToLower(strings.TrimSpace(candidate))
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(lower, "selected model is at capacity") ||
|
||||
strings.Contains(lower, "model is at capacity. please try a different model") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -42,8 +42,8 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
if gotKey != expectedKey {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||
}
|
||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
||||
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", gotConversation)
|
||||
}
|
||||
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||
|
||||
79
internal/runtime/executor/codex_executor_compact_test.go
Normal file
79
internal/runtime/executor/codex_executor_compact_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCodexExecutorCompactAddsDefaultInstructions(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payload string
|
||||
}{
|
||||
{
|
||||
name: "missing instructions",
|
||||
payload: `{"model":"gpt-5.4","input":"hello"}`,
|
||||
},
|
||||
{
|
||||
name: "null instructions",
|
||||
payload: `{"model":"gpt-5.4","instructions":null,"input":"hello"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody = body
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"api_key": "test",
|
||||
}}
|
||||
|
||||
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(tc.payload),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
Alt: "responses/compact",
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if gotPath != "/responses/compact" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/responses/compact")
|
||||
}
|
||||
if !gjson.GetBytes(gotBody, "instructions").Exists() {
|
||||
t.Fatalf("expected instructions in compact request body, got %s", string(gotBody))
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
|
||||
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "instructions").String() != "" {
|
||||
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
|
||||
}
|
||||
if string(resp.Payload) != `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}` {
|
||||
t.Fatalf("payload = %s", string(resp.Payload))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
123
internal/runtime/executor/codex_executor_instructions_test.go
Normal file
123
internal/runtime/executor/codex_executor_instructions_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCodexExecutorExecuteNormalizesNullInstructions(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody = body
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"api_key": "test",
|
||||
}}
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if gotPath != "/responses" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/responses")
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
|
||||
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "instructions").String() != "" {
|
||||
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorExecuteStreamNormalizesNullInstructions(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody = body
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"api_key": "test",
|
||||
}}
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for range result.Chunks {
|
||||
}
|
||||
if gotPath != "/responses" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/responses")
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
|
||||
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "instructions").String() != "" {
|
||||
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCountTokensTreatsNullInstructionsAsEmpty(t *testing.T) {
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
|
||||
nullResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens(null) error: %v", err)
|
||||
}
|
||||
|
||||
emptyResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4","instructions":"","input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens(empty) error: %v", err)
|
||||
}
|
||||
|
||||
if string(nullResp.Payload) != string(emptyResp.Payload) {
|
||||
t.Fatalf("token count payload mismatch:\nnull=%s\nempty=%s", string(nullResp.Payload), string(emptyResp.Payload))
|
||||
}
|
||||
}
|
||||
@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
|
||||
|
||||
err := newCodexStatusErr(http.StatusBadRequest, body)
|
||||
|
||||
if got := err.StatusCode(); got != http.StatusTooManyRequests {
|
||||
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||
}
|
||||
if err.RetryAfter() != nil {
|
||||
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
@@ -15,10 +15,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -44,10 +46,18 @@ const (
|
||||
type CodexWebsocketsExecutor struct {
|
||||
*CodexExecutor
|
||||
|
||||
sessMu sync.Mutex
|
||||
store *codexWebsocketSessionStore
|
||||
}
|
||||
|
||||
type codexWebsocketSessionStore struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*codexWebsocketSession
|
||||
}
|
||||
|
||||
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
|
||||
sessions: make(map[string]*codexWebsocketSession),
|
||||
}
|
||||
|
||||
type codexWebsocketSession struct {
|
||||
sessionID string
|
||||
|
||||
@@ -71,7 +81,7 @@ type codexWebsocketSession struct {
|
||||
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
||||
return &CodexWebsocketsExecutor{
|
||||
CodexExecutor: NewCodexExecutor(cfg),
|
||||
sessions: make(map[string]*codexWebsocketSession),
|
||||
store: globalCodexWebsocketSessionStore,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,8 +165,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
@@ -173,8 +183,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
@@ -209,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
wsReqLog := helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -219,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
}
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
|
||||
|
||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if respHS != nil {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||
}
|
||||
if errDial != nil {
|
||||
bodyErr := websocketHandshakeBody(respHS)
|
||||
if len(bodyErr) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
||||
if respHS != nil {
|
||||
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
|
||||
}
|
||||
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
||||
return e.CodexExecutor.Execute(ctx, auth, req, opts)
|
||||
@@ -236,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
if respHS != nil && respHS.StatusCode > 0 {
|
||||
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
||||
}
|
||||
recordAPIResponseError(ctx, e.cfg, errDial)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
|
||||
return resp, errDial
|
||||
}
|
||||
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
|
||||
if sess == nil {
|
||||
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
||||
defer func() {
|
||||
@@ -268,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
// Retry once with a fresh websocket connection. This is mainly to handle
|
||||
// upstream closing the socket between sequential requests within the same
|
||||
// execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry == nil && connRetry != nil {
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -282,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
|
||||
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
|
||||
conn = connRetry
|
||||
wsReqBody = wsReqBodyRetry
|
||||
} else {
|
||||
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
||||
recordAPIResponseError(ctx, e.cfg, errSendRetry)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
|
||||
return resp, errSendRetry
|
||||
}
|
||||
} else {
|
||||
recordAPIResponseError(ctx, e.cfg, errDialRetry)
|
||||
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
|
||||
return resp, errDialRetry
|
||||
}
|
||||
} else {
|
||||
recordAPIResponseError(ctx, e.cfg, errSend)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
|
||||
return resp, errSend
|
||||
}
|
||||
}
|
||||
@@ -306,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
if msgType != websocket.TextMessage {
|
||||
@@ -315,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
||||
}
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
|
||||
return resp, err
|
||||
}
|
||||
continue
|
||||
@@ -325,21 +335,21 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, payload)
|
||||
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
|
||||
|
||||
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
||||
}
|
||||
recordAPIResponseError(ctx, e.cfg, wsErr)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
|
||||
return resp, wsErr
|
||||
}
|
||||
|
||||
payload = normalizeCodexWebsocketCompletion(payload)
|
||||
eventType := gjson.GetBytes(payload, "type").String()
|
||||
if eventType == "response.completed" {
|
||||
if detail, ok := parseCodexUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseCodexUsage(payload); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
||||
@@ -364,8 +374,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
@@ -376,8 +386,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
|
||||
|
||||
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
||||
@@ -403,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
wsReqLog := helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -413,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
}
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
|
||||
|
||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
var upstreamHeaders http.Header
|
||||
if respHS != nil {
|
||||
upstreamHeaders = respHS.Header.Clone()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||
}
|
||||
if errDial != nil {
|
||||
bodyErr := websocketHandshakeBody(respHS)
|
||||
if len(bodyErr) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
||||
if respHS != nil {
|
||||
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
|
||||
}
|
||||
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
||||
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||
@@ -432,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
if respHS != nil && respHS.StatusCode > 0 {
|
||||
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
||||
}
|
||||
recordAPIResponseError(ctx, e.cfg, errDial)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
|
||||
if sess != nil {
|
||||
sess.reqMu.Unlock()
|
||||
}
|
||||
return nil, errDial
|
||||
}
|
||||
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
|
||||
|
||||
if sess == nil {
|
||||
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
||||
@@ -451,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errSend)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
|
||||
|
||||
// Retry once with a new websocket connection for the same execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry != nil || connRetry == nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDialRetry)
|
||||
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
|
||||
sess.clearActive(readCh)
|
||||
sess.reqMu.Unlock()
|
||||
return nil, errDialRetry
|
||||
}
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
Headers: wsHeaders.Clone(),
|
||||
@@ -475,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
|
||||
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errSendRetry)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
|
||||
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
||||
sess.clearActive(readCh)
|
||||
sess.reqMu.Unlock()
|
||||
@@ -542,8 +554,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
terminateReason = "read_error"
|
||||
terminateErr = errRead
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
|
||||
reporter.PublishFailure(ctx)
|
||||
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
|
||||
return
|
||||
}
|
||||
@@ -552,8 +564,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
err = fmt.Errorf("codex websockets executor: unexpected binary message")
|
||||
terminateReason = "unexpected_binary"
|
||||
terminateErr = err
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
|
||||
reporter.PublishFailure(ctx)
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
||||
}
|
||||
@@ -567,13 +579,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, payload)
|
||||
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
|
||||
|
||||
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
||||
terminateReason = "upstream_error"
|
||||
terminateErr = wsErr
|
||||
recordAPIResponseError(ctx, e.cfg, wsErr)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
|
||||
reporter.PublishFailure(ctx)
|
||||
if sess != nil {
|
||||
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
||||
}
|
||||
@@ -584,8 +596,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
payload = normalizeCodexWebsocketCompletion(payload)
|
||||
eventType := gjson.GetBytes(payload, "type").String()
|
||||
if eventType == "response.completed" || eventType == "response.done" {
|
||||
if detail, ok := parseCodexUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseCodexUsage(payload); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -767,19 +779,19 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
return rawJSON, headers
|
||||
}
|
||||
|
||||
var cache codexCache
|
||||
var cache helps.CodexCache
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||
if cached, ok := getCodexCache(key); ok {
|
||||
if cached, ok := helps.GetCodexCache(key); ok {
|
||||
cache = cached
|
||||
} else {
|
||||
cache = codexCache{
|
||||
cache = helps.CodexCache{
|
||||
ID: uuid.New().String(),
|
||||
Expire: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
setCodexCache(key, cache)
|
||||
helps.SetCodexCache(key, cache)
|
||||
}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
@@ -791,7 +803,6 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
if cache.ID != "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||
headers.Set("Conversation_id", cache.ID)
|
||||
headers.Set("Session_id", cache.ID)
|
||||
}
|
||||
|
||||
return rawJSON, headers
|
||||
@@ -806,11 +817,11 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
}
|
||||
|
||||
var ginHeaders http.Header
|
||||
if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header.Clone()
|
||||
}
|
||||
|
||||
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||
@@ -826,8 +837,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
}
|
||||
headers.Del("User-Agent")
|
||||
|
||||
isAPIKey := false
|
||||
if auth != nil && auth.Attributes != nil {
|
||||
@@ -1011,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
|
||||
return line
|
||||
}
|
||||
|
||||
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
|
||||
upgradeInfo := info
|
||||
upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL)
|
||||
upgradeInfo.Method = http.MethodGet
|
||||
upgradeInfo.Body = nil
|
||||
upgradeInfo.Headers = info.Headers.Clone()
|
||||
if upgradeInfo.Headers == nil {
|
||||
upgradeInfo.Headers = make(http.Header)
|
||||
}
|
||||
if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" {
|
||||
upgradeInfo.Headers.Set("Connection", "Upgrade")
|
||||
}
|
||||
if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" {
|
||||
upgradeInfo.Headers.Set("Upgrade", "websocket")
|
||||
}
|
||||
return upgradeInfo
|
||||
}
|
||||
|
||||
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
|
||||
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
|
||||
}
|
||||
|
||||
func websocketHandshakeBody(resp *http.Response) []byte {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil
|
||||
@@ -1055,16 +1094,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
e.sessMu.Lock()
|
||||
defer e.sessMu.Unlock()
|
||||
if e.sessions == nil {
|
||||
e.sessions = make(map[string]*codexWebsocketSession)
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
|
||||
store := e.store
|
||||
if store == nil {
|
||||
store = globalCodexWebsocketSessionStore
|
||||
}
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if store.sessions == nil {
|
||||
store.sessions = make(map[string]*codexWebsocketSession)
|
||||
}
|
||||
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
|
||||
return sess
|
||||
}
|
||||
sess := &codexWebsocketSession{sessionID: sessionID}
|
||||
e.sessions[sessionID] = sess
|
||||
store.sessions[sessionID] = sess
|
||||
return sess
|
||||
}
|
||||
|
||||
@@ -1210,14 +1256,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
|
||||
return
|
||||
}
|
||||
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
||||
e.closeAllExecutionSessions("executor_replaced")
|
||||
// Executor replacement can happen during hot reload (config/credential changes).
|
||||
// Do not force-close upstream websocket sessions here, otherwise in-flight
|
||||
// downstream websocket requests get interrupted.
|
||||
return
|
||||
}
|
||||
|
||||
e.sessMu.Lock()
|
||||
sess := e.sessions[sessionID]
|
||||
delete(e.sessions, sessionID)
|
||||
e.sessMu.Unlock()
|
||||
store := e.store
|
||||
if store == nil {
|
||||
store = globalCodexWebsocketSessionStore
|
||||
}
|
||||
store.mu.Lock()
|
||||
sess := store.sessions[sessionID]
|
||||
delete(store.sessions, sessionID)
|
||||
store.mu.Unlock()
|
||||
|
||||
e.closeExecutionSession(sess, "session_closed")
|
||||
}
|
||||
@@ -1227,15 +1279,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
||||
return
|
||||
}
|
||||
|
||||
e.sessMu.Lock()
|
||||
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
|
||||
for sessionID, sess := range e.sessions {
|
||||
delete(e.sessions, sessionID)
|
||||
store := e.store
|
||||
if store == nil {
|
||||
store = globalCodexWebsocketSessionStore
|
||||
}
|
||||
store.mu.Lock()
|
||||
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
|
||||
for sessionID, sess := range store.sessions {
|
||||
delete(store.sessions, sessionID)
|
||||
if sess != nil {
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
}
|
||||
e.sessMu.Unlock()
|
||||
store.mu.Unlock()
|
||||
|
||||
for i := range sessions {
|
||||
e.closeExecutionSession(sessions[i], reason)
|
||||
@@ -1243,6 +1299,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
||||
}
|
||||
|
||||
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
||||
closeCodexWebsocketSession(sess, reason)
|
||||
}
|
||||
|
||||
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
|
||||
if sess == nil {
|
||||
return
|
||||
}
|
||||
@@ -1283,6 +1343,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
|
||||
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
||||
}
|
||||
|
||||
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
|
||||
// associated with the supplied auth ID.
|
||||
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if reason == "" {
|
||||
reason = "auth_removed"
|
||||
}
|
||||
|
||||
store := globalCodexWebsocketSessionStore
|
||||
if store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
type sessionItem struct {
|
||||
sessionID string
|
||||
sess *codexWebsocketSession
|
||||
}
|
||||
|
||||
store.mu.Lock()
|
||||
items := make([]sessionItem, 0, len(store.sessions))
|
||||
for sessionID, sess := range store.sessions {
|
||||
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
matches := make([]sessionItem, 0)
|
||||
for i := range items {
|
||||
sess := items[i].sess
|
||||
if sess == nil {
|
||||
continue
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
sessAuthID := strings.TrimSpace(sess.authID)
|
||||
sess.connMu.Unlock()
|
||||
if sessAuthID == authID {
|
||||
matches = append(matches, items[i])
|
||||
}
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
toClose := make([]*codexWebsocketSession, 0, len(matches))
|
||||
store.mu.Lock()
|
||||
for i := range matches {
|
||||
current, ok := store.sessions[matches[i].sessionID]
|
||||
if !ok || current == nil || current != matches[i].sess {
|
||||
continue
|
||||
}
|
||||
delete(store.sessions, matches[i].sessionID)
|
||||
toClose = append(toClose, current)
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
for i := range toClose {
|
||||
closeCodexWebsocketSession(toClose[i], reason)
|
||||
}
|
||||
}
|
||||
|
||||
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
||||
// 1. The downstream transport is websocket, and
|
||||
// 2. The selected auth enables websockets.
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
|
||||
sessionID := "test-session-store-survives-replace"
|
||||
|
||||
globalCodexWebsocketSessionStore.mu.Lock()
|
||||
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
|
||||
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||
|
||||
exec1 := NewCodexWebsocketsExecutor(nil)
|
||||
sess1 := exec1.getOrCreateSession(sessionID)
|
||||
if sess1 == nil {
|
||||
t.Fatalf("expected session to be created")
|
||||
}
|
||||
|
||||
exec2 := NewCodexWebsocketsExecutor(nil)
|
||||
sess2 := exec2.getOrCreateSession(sessionID)
|
||||
if sess2 == nil {
|
||||
t.Fatalf("expected session to be available across executors")
|
||||
}
|
||||
if sess1 != sess2 {
|
||||
t.Fatalf("expected the same session instance across executors")
|
||||
}
|
||||
|
||||
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
|
||||
|
||||
globalCodexWebsocketSessionStore.mu.Lock()
|
||||
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||
if !stillPresent {
|
||||
t.Fatalf("expected session to remain after executor replacement close marker")
|
||||
}
|
||||
|
||||
exec2.CloseExecutionSession(sessionID)
|
||||
|
||||
globalCodexWebsocketSessionStore.mu.Lock()
|
||||
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||
if presentAfterClose {
|
||||
t.Fatalf("expected session to be removed after explicit close")
|
||||
}
|
||||
}
|
||||
@@ -38,8 +38,8 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
}
|
||||
if got := headers.Get("Version"); got != "" {
|
||||
t.Fatalf("Version = %q, want empty", got)
|
||||
@@ -97,8 +97,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
||||
@@ -129,8 +129,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
|
||||
|
||||
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
||||
|
||||
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
|
||||
if gotVal := got.Get("User-Agent"); gotVal != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", gotVal)
|
||||
}
|
||||
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
||||
@@ -155,8 +155,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "config-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
||||
@@ -177,8 +177,8 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
|
||||
129
internal/runtime/executor/compat_helpers.go
Normal file
129
internal/runtime/executor/compat_helpers.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
}
|
||||
|
||||
func parseOpenAIUsage(data []byte) usage.Detail {
|
||||
return helps.ParseOpenAIUsage(data)
|
||||
}
|
||||
|
||||
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return helps.ParseOpenAIStreamUsage(line)
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
|
||||
return helps.ParseOpenAIUsage(data)
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return helps.ParseOpenAIStreamUsage(line)
|
||||
}
|
||||
|
||||
func getTokenizer(model string) (tokenizer.Codec, error) {
|
||||
return helps.TokenizerForModel(model)
|
||||
}
|
||||
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return helps.CountOpenAIChatTokens(enc, payload)
|
||||
}
|
||||
|
||||
func countClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return helps.CountClaudeChatTokens(enc, payload)
|
||||
}
|
||||
|
||||
func buildOpenAIUsageJSON(count int64) []byte {
|
||||
return helps.BuildOpenAIUsageJSON(count)
|
||||
}
|
||||
|
||||
type upstreamRequestLog = helps.UpstreamRequestLog
|
||||
|
||||
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
|
||||
helps.RecordAPIRequest(ctx, cfg, info)
|
||||
}
|
||||
|
||||
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
helps.RecordAPIResponseMetadata(ctx, cfg, status, headers)
|
||||
}
|
||||
|
||||
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||
helps.RecordAPIResponseError(ctx, cfg, err)
|
||||
}
|
||||
|
||||
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||
helps.AppendAPIResponseChunk(ctx, cfg, chunk)
|
||||
}
|
||||
|
||||
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
return helps.PayloadRequestedModel(opts, fallback)
|
||||
}
|
||||
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
return helps.ApplyPayloadConfigWithRoot(cfg, model, protocol, root, payload, original, requestedModel)
|
||||
}
|
||||
|
||||
func summarizeErrorBody(contentType string, body []byte) string {
|
||||
return helps.SummarizeErrorBody(contentType, body)
|
||||
}
|
||||
|
||||
func apiKeyFromContext(ctx context.Context) string {
|
||||
return helps.APIKeyFromContext(ctx)
|
||||
}
|
||||
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
return helps.TokenizerForModel(model)
|
||||
}
|
||||
|
||||
func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
helps.CollectOpenAIContent(content, segments)
|
||||
}
|
||||
|
||||
type usageReporter struct {
|
||||
reporter *helps.UsageReporter
|
||||
}
|
||||
|
||||
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
|
||||
return &usageReporter{reporter: helps.NewUsageReporter(ctx, provider, model, auth)}
|
||||
}
|
||||
|
||||
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.Publish(ctx, detail)
|
||||
}
|
||||
|
||||
func (r *usageReporter) publishFailure(ctx context.Context) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.PublishFailure(ctx)
|
||||
}
|
||||
|
||||
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.TrackFailure(ctx, errPtr)
|
||||
}
|
||||
|
||||
func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.EnsurePublished(ctx)
|
||||
}
|
||||
1719
internal/runtime/executor/cursor_executor.go
Normal file
1719
internal/runtime/executor/cursor_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -81,6 +82,11 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(req, "unknown")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,8 +118,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
return resp, err
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
@@ -132,8 +138,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -190,7 +196,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: reqHTTP.Header.Clone(),
|
||||
@@ -204,7 +211,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
httpResp, errDo := httpClient.Do(reqHTTP)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
@@ -213,15 +220,15 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
@@ -230,7 +237,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
if httpResp.StatusCode == 429 {
|
||||
if idx+1 < len(models) {
|
||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||
@@ -245,7 +252,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, lastBody)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
|
||||
}
|
||||
if lastStatus == 0 {
|
||||
lastStatus = 429
|
||||
@@ -266,8 +273,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
@@ -286,8 +293,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
|
||||
projectID := resolveGeminiProjectID(auth)
|
||||
|
||||
@@ -335,7 +342,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "text/event-stream")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: reqHTTP.Header.Clone(),
|
||||
@@ -349,25 +357,25 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
httpResp, errDo := httpClient.Do(reqHTTP)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
err = errDo
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return nil, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
if httpResp.StatusCode == 429 {
|
||||
if idx+1 < len(models) {
|
||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||
@@ -394,9 +402,9 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseGeminiCLIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseGeminiCLIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
@@ -411,8 +419,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
return
|
||||
@@ -420,13 +428,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
data, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errRead}
|
||||
return
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
@@ -443,7 +451,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, lastBody)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
|
||||
}
|
||||
if lastStatus == 0 {
|
||||
lastStatus = 429
|
||||
@@ -516,7 +524,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: reqHTTP.Header.Clone(),
|
||||
@@ -530,17 +539,19 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
resp, errDo := httpClient.Do(reqHTTP)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
data, errRead := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
@@ -611,7 +622,7 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
|
||||
}
|
||||
|
||||
ctxToken := ctx
|
||||
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
||||
}
|
||||
|
||||
@@ -707,7 +718,7 @@ func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
|
||||
}
|
||||
|
||||
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -85,7 +86,7 @@ func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -110,8 +111,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
|
||||
apiKey, bearer := geminiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
@@ -130,8 +131,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := "generateContent"
|
||||
@@ -165,7 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -177,10 +178,10 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -188,21 +189,21 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
@@ -218,8 +219,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
apiKey, bearer := geminiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
@@ -237,8 +238,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
@@ -268,7 +269,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -280,17 +281,17 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -310,14 +311,14 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
filtered := FilterSSEUsageMetadata(line)
|
||||
payload := jsonPayload(filtered)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
filtered := helps.FilterSSEUsageMetadata(line)
|
||||
payload := helps.JSONPayload(filtered)
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
if detail, ok := parseGeminiStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
if detail, ok := helps.ParseGeminiStreamUsage(payload); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
@@ -329,8 +330,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
@@ -381,7 +382,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -393,23 +394,27 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, helps.SummarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ import (
|
||||
|
||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -227,7 +229,7 @@ func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -301,8 +303,8 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
|
||||
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
var body []byte
|
||||
|
||||
@@ -332,8 +334,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
|
||||
@@ -362,6 +364,11 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
return resp, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -369,7 +376,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -381,10 +388,10 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
defer func() {
|
||||
@@ -392,21 +399,21 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
|
||||
|
||||
// For Imagen models, convert response to Gemini format before translation
|
||||
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||
@@ -427,8 +434,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
@@ -447,8 +454,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, false)
|
||||
@@ -477,6 +484,11 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -484,7 +496,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -496,10 +508,10 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
defer func() {
|
||||
@@ -507,21 +519,21 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
@@ -532,8 +544,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
@@ -552,8 +564,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
@@ -581,6 +593,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
return nil, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -588,7 +605,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -600,17 +617,17 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, errDo
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -630,9 +647,9 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
@@ -644,8 +661,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
@@ -656,8 +673,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
@@ -676,8 +693,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
@@ -705,6 +722,11 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -712,7 +734,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -724,17 +746,17 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, errDo
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -754,9 +776,9 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
@@ -768,8 +790,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
@@ -812,6 +834,11 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -819,7 +846,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -831,10 +858,10 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
defer func() {
|
||||
@@ -842,19 +869,19 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
@@ -896,6 +923,11 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -903,7 +935,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -915,10 +947,10 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
defer func() {
|
||||
@@ -926,19 +958,19 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
@@ -1012,7 +1044,7 @@ func vertexBaseURL(location string) string {
|
||||
}
|
||||
|
||||
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
||||
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
}
|
||||
// Use cloud-platform scope for Vertex AI.
|
||||
|
||||
@@ -76,6 +76,9 @@ func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||
@@ -83,15 +86,25 @@ func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *test
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "github-copilot-test-client"
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
|
||||
ID: "gpt-5.4",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
}})
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
|
||||
{
|
||||
ID: "gpt-5.4",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4-mini",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||
}
|
||||
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
|
||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
|
||||
@@ -30,12 +30,20 @@ const (
|
||||
gitLabChatEndpoint = "/api/v4/chat/completions"
|
||||
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
|
||||
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
|
||||
gitLabContext1MBeta = "context-1m-2025-08-07"
|
||||
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
|
||||
)
|
||||
|
||||
type GitLabExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type gitLabCatalogModel struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Provider string
|
||||
}
|
||||
|
||||
type gitLabPrompt struct {
|
||||
Instruction string
|
||||
FileName string
|
||||
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
|
||||
Finished bool
|
||||
}
|
||||
|
||||
var gitLabAgenticCatalog = []gitLabCatalogModel{
|
||||
{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
|
||||
{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
|
||||
{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
|
||||
}
|
||||
|
||||
var gitLabModelAliases = map[string]string{
|
||||
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
|
||||
}
|
||||
|
||||
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
||||
return &GitLabExecutor{cfg: cfg}
|
||||
}
|
||||
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
|
||||
nativeReq := req
|
||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||
}
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
|
||||
nativeReq := req
|
||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
|
||||
return NewClaudeExecutor(e.cfg), nativeAuth
|
||||
}
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
|
||||
return NewCodexExecutor(e.cfg), nativeAuth
|
||||
}
|
||||
return nil, nil
|
||||
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
if auth != nil {
|
||||
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
|
||||
}
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
for key, value := range gitLabGatewayHeaders(auth, "") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
}
|
||||
}
|
||||
|
||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
|
||||
out := make(map[string]string)
|
||||
switch typed := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key != "" && value != "" {
|
||||
out[key] = value
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||
if ok {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key != "" && value != "" {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||
if strValue != "" {
|
||||
out[key] = strValue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||
if strValue != "" {
|
||||
out[key] = strValue
|
||||
}
|
||||
}
|
||||
if _, ok := out["User-Agent"]; !ok {
|
||||
out["User-Agent"] = gitLabNativeUserAgent
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
|
||||
if _, ok := out["anthropic-beta"]; !ok {
|
||||
out["anthropic-beta"] = gitLabContext1MBeta
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
|
||||
return promptTokens, int64(completionCount)
|
||||
}
|
||||
|
||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesAnthropicGateway(auth) {
|
||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesAnthropicGateway(auth, requestedModel) {
|
||||
return nil, false
|
||||
}
|
||||
baseURL := gitLabAnthropicGatewayBaseURL(auth)
|
||||
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
||||
}
|
||||
nativeAuth.Attributes["api_key"] = token
|
||||
nativeAuth.Attributes["base_url"] = baseURL
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
|
||||
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
||||
return nativeAuth, true
|
||||
}
|
||||
|
||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesOpenAIGateway(auth) {
|
||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesOpenAIGateway(auth, requestedModel) {
|
||||
return nil, false
|
||||
}
|
||||
baseURL := gitLabOpenAIGatewayBaseURL(auth)
|
||||
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
||||
}
|
||||
nativeAuth.Attributes["api_key"] = token
|
||||
nativeAuth.Attributes["base_url"] = baseURL
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
for key, value := range gitLabGatewayHeaders(auth, "openai") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
||||
return nativeAuth, true
|
||||
}
|
||||
|
||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool {
|
||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
provider = inferGitLabProviderFromModel(modelName)
|
||||
}
|
||||
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||
return provider == "anthropic" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||
}
|
||||
|
||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool {
|
||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
provider = inferGitLabProviderFromModel(modelName)
|
||||
}
|
||||
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||
return provider == "openai" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||
}
|
||||
|
||||
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
|
||||
modelName := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
|
||||
if provider := inferGitLabProviderFromModel(modelName); provider != "" {
|
||||
return provider
|
||||
}
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
provider = inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func inferGitLabProviderFromModel(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
|
||||
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
|
||||
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
|
||||
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
|
||||
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
|
||||
return mapped
|
||||
}
|
||||
return requested
|
||||
}
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
|
||||
}
|
||||
|
||||
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
models := make([]*registry.ModelInfo, 0, 4)
|
||||
seen := make(map[string]struct{}, 4)
|
||||
models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
|
||||
seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
|
||||
addModel := func(id, displayName, provider string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
}
|
||||
|
||||
addModel("gitlab-duo", "GitLab Duo", "gitlab")
|
||||
for _, model := range gitLabAgenticCatalog {
|
||||
addModel(model.ID, model.DisplayName, model.Provider)
|
||||
}
|
||||
for alias, upstream := range gitLabModelAliases {
|
||||
target := strings.TrimSpace(upstream)
|
||||
displayName := "GitLab Duo Alias"
|
||||
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
|
||||
if provider != "" {
|
||||
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
|
||||
}
|
||||
addModel(alias, displayName, provider)
|
||||
}
|
||||
if auth == nil {
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotBetaHeader = r.Header.Get("anthropic-beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "duo-chat-gpt-5-codex",
|
||||
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotBetaHeader != gitLabContext1MBeta {
|
||||
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if gotModel != "duo-chat-gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
|
||||
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"oauth_client_secret": "client-secret",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotPath, gotBetaHeader, gotUserAgent string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotBetaHeader = r.Header.Get("Anthropic-Beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
|
||||
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type codexCache struct {
|
||||
type CodexCache struct {
|
||||
ID string
|
||||
Expire time.Time
|
||||
}
|
||||
@@ -13,7 +13,7 @@ type codexCache struct {
|
||||
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
|
||||
// Protected by codexCacheMu. Entries expire after 1 hour.
|
||||
var (
|
||||
codexCacheMap = make(map[string]codexCache)
|
||||
codexCacheMap = make(map[string]CodexCache)
|
||||
codexCacheMu sync.RWMutex
|
||||
)
|
||||
|
||||
@@ -50,20 +50,20 @@ func purgeExpiredCodexCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
||||
func getCodexCache(key string) (codexCache, bool) {
|
||||
// GetCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
||||
func GetCodexCache(key string) (CodexCache, bool) {
|
||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||
codexCacheMu.RLock()
|
||||
cache, ok := codexCacheMap[key]
|
||||
codexCacheMu.RUnlock()
|
||||
if !ok || cache.Expire.Before(time.Now()) {
|
||||
return codexCache{}, false
|
||||
return CodexCache{}, false
|
||||
}
|
||||
return cache, true
|
||||
}
|
||||
|
||||
// setCodexCache stores a cache entry.
|
||||
func setCodexCache(key string, cache codexCache) {
|
||||
// SetCodexCache stores a cache entry.
|
||||
func SetCodexCache(key string, cache CodexCache) {
|
||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||
codexCacheMu.Lock()
|
||||
codexCacheMap[key] = cache
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@@ -32,7 +32,7 @@ var (
|
||||
claudeDeviceProfileCacheMu sync.RWMutex
|
||||
claudeDeviceProfileCacheCleanupOnce sync.Once
|
||||
|
||||
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
|
||||
ClaudeDeviceProfileBeforeCandidateStore func(ClaudeDeviceProfile)
|
||||
)
|
||||
|
||||
type claudeCLIVersion struct {
|
||||
@@ -63,29 +63,43 @@ func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
|
||||
}
|
||||
}
|
||||
|
||||
type claudeDeviceProfile struct {
|
||||
type ClaudeDeviceProfile struct {
|
||||
UserAgent string
|
||||
PackageVersion string
|
||||
RuntimeVersion string
|
||||
OS string
|
||||
Arch string
|
||||
Version claudeCLIVersion
|
||||
HasVersion bool
|
||||
version claudeCLIVersion
|
||||
hasVersion bool
|
||||
}
|
||||
|
||||
type claudeDeviceProfileCacheEntry struct {
|
||||
profile claudeDeviceProfile
|
||||
profile ClaudeDeviceProfile
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||
func ClaudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||
return false
|
||||
}
|
||||
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
||||
}
|
||||
|
||||
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
||||
func ResetClaudeDeviceProfileCache() {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func MapStainlessOS() string {
|
||||
return mapStainlessOS()
|
||||
}
|
||||
|
||||
func MapStainlessArch() string {
|
||||
return mapStainlessArch()
|
||||
}
|
||||
|
||||
func defaultClaudeDeviceProfile(cfg *config.Config) ClaudeDeviceProfile {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if strings.TrimSpace(cfgVal) != "" {
|
||||
return strings.TrimSpace(cfgVal)
|
||||
@@ -98,7 +112,7 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
profile := claudeDeviceProfile{
|
||||
profile := ClaudeDeviceProfile{
|
||||
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
||||
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
||||
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
||||
@@ -106,8 +120,8 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
||||
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
||||
}
|
||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||
profile.Version = version
|
||||
profile.HasVersion = true
|
||||
profile.version = version
|
||||
profile.hasVersion = true
|
||||
}
|
||||
return profile
|
||||
}
|
||||
@@ -162,17 +176,17 @@ func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
|
||||
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
||||
}
|
||||
|
||||
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
|
||||
if candidate.UserAgent == "" || !candidate.HasVersion {
|
||||
func shouldUpgradeClaudeDeviceProfile(candidate, current ClaudeDeviceProfile) bool {
|
||||
if candidate.UserAgent == "" || !candidate.hasVersion {
|
||||
return false
|
||||
}
|
||||
if current.UserAgent == "" || !current.HasVersion {
|
||||
if current.UserAgent == "" || !current.hasVersion {
|
||||
return true
|
||||
}
|
||||
return candidate.Version.Compare(current.Version) > 0
|
||||
return candidate.version.Compare(current.version) > 0
|
||||
}
|
||||
|
||||
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||
func pinClaudeDeviceProfilePlatform(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
|
||||
profile.OS = baseline.OS
|
||||
profile.Arch = baseline.Arch
|
||||
return profile
|
||||
@@ -180,38 +194,38 @@ func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claud
|
||||
|
||||
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
||||
// baseline platform and enforces the baseline software fingerprint as a floor.
|
||||
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||
func normalizeClaudeDeviceProfile(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
|
||||
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
||||
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||
if profile.UserAgent == "" || !profile.hasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||
profile.UserAgent = baseline.UserAgent
|
||||
profile.PackageVersion = baseline.PackageVersion
|
||||
profile.RuntimeVersion = baseline.RuntimeVersion
|
||||
profile.Version = baseline.Version
|
||||
profile.HasVersion = baseline.HasVersion
|
||||
profile.version = baseline.version
|
||||
profile.hasVersion = baseline.hasVersion
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
|
||||
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (ClaudeDeviceProfile, bool) {
|
||||
if headers == nil {
|
||||
return claudeDeviceProfile{}, false
|
||||
return ClaudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
||||
version, ok := parseClaudeCLIVersion(userAgent)
|
||||
if !ok {
|
||||
return claudeDeviceProfile{}, false
|
||||
return ClaudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
profile := claudeDeviceProfile{
|
||||
profile := ClaudeDeviceProfile{
|
||||
UserAgent: userAgent,
|
||||
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
||||
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
||||
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
||||
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
||||
Version: version,
|
||||
HasVersion: true,
|
||||
version: version,
|
||||
hasVersion: true,
|
||||
}
|
||||
return profile, true
|
||||
}
|
||||
@@ -263,7 +277,7 @@ func purgeExpiredClaudeDeviceProfiles() {
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
|
||||
func ResolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) ClaudeDeviceProfile {
|
||||
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
||||
|
||||
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
||||
@@ -283,8 +297,8 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
|
||||
claudeDeviceProfileCacheMu.RUnlock()
|
||||
|
||||
if hasCandidate {
|
||||
if claudeDeviceProfileBeforeCandidateStore != nil {
|
||||
claudeDeviceProfileBeforeCandidateStore(candidate)
|
||||
if ClaudeDeviceProfileBeforeCandidateStore != nil {
|
||||
ClaudeDeviceProfileBeforeCandidateStore(candidate)
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
@@ -324,7 +338,7 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
|
||||
return baseline
|
||||
}
|
||||
|
||||
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
|
||||
func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfile) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
@@ -344,7 +358,17 @@ func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfil
|
||||
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||
}
|
||||
|
||||
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||
// DefaultClaudeVersion returns the version string (e.g. "2.1.63") from the
|
||||
// current baseline device profile. It extracts the version from the User-Agent.
|
||||
func DefaultClaudeVersion(cfg *config.Config) string {
|
||||
profile := defaultClaudeDeviceProfile(cfg)
|
||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||
return strconv.Itoa(version.major) + "." + strconv.Itoa(version.minor) + "." + strconv.Itoa(version.patch)
|
||||
}
|
||||
return "2.1.63"
|
||||
}
|
||||
|
||||
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
@@ -18,9 +18,9 @@ type SensitiveWordMatcher struct {
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
// buildSensitiveWordMatcher compiles a regex from the word list.
|
||||
// BuildSensitiveWordMatcher compiles a regex from the word list.
|
||||
// Words are sorted by length (longest first) for proper matching.
|
||||
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
||||
func BuildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -81,9 +81,9 @@ func (m *SensitiveWordMatcher) obfuscateText(text string) string {
|
||||
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
|
||||
}
|
||||
|
||||
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||
// ObfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||
// in system blocks and message content.
|
||||
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
func ObfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
if matcher == nil || matcher.regex == nil {
|
||||
return payload
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -28,9 +28,17 @@ func isValidUserID(userID string) bool {
|
||||
return userIDPattern.MatchString(userID)
|
||||
}
|
||||
|
||||
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
|
||||
func GenerateFakeUserID() string {
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
func IsValidUserID(userID string) bool {
|
||||
return isValidUserID(userID)
|
||||
}
|
||||
|
||||
// ShouldCloak determines if request should be cloaked based on config and client User-Agent.
|
||||
// Returns true if cloaking should be applied.
|
||||
func shouldCloak(cloakMode string, userAgent string) bool {
|
||||
func ShouldCloak(cloakMode string, userAgent string) bool {
|
||||
switch strings.ToLower(cloakMode) {
|
||||
case "always":
|
||||
return true
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,13 +20,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
||||
apiRequestKey = "API_REQUEST"
|
||||
apiResponseKey = "API_RESPONSE"
|
||||
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
||||
apiRequestKey = "API_REQUEST"
|
||||
apiResponseKey = "API_RESPONSE"
|
||||
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
|
||||
)
|
||||
|
||||
// upstreamRequestLog captures the outbound upstream request details for logging.
|
||||
type upstreamRequestLog struct {
|
||||
// UpstreamRequestLog captures the outbound upstream request details for logging.
|
||||
type UpstreamRequestLog struct {
|
||||
URL string
|
||||
Method string
|
||||
Headers http.Header
|
||||
@@ -46,11 +48,12 @@ type upstreamAttempt struct {
|
||||
headersWritten bool
|
||||
bodyStarted bool
|
||||
bodyHasContent bool
|
||||
prevWasSSEEvent bool
|
||||
errorWritten bool
|
||||
}
|
||||
|
||||
// recordAPIRequest stores the upstream request metadata in Gin context for request logging.
|
||||
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
|
||||
// RecordAPIRequest stores the upstream request metadata in Gin context for request logging.
|
||||
func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
@@ -96,8 +99,8 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
|
||||
updateAggregatedRequest(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// recordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
|
||||
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
|
||||
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
@@ -122,8 +125,8 @@ func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status i
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// recordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
|
||||
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||
// RecordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
|
||||
func RecordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||
if cfg == nil || !cfg.RequestLog || err == nil {
|
||||
return
|
||||
}
|
||||
@@ -147,8 +150,8 @@ func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error)
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// appendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
|
||||
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||
// AppendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
|
||||
func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
@@ -173,15 +176,157 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
attempt.response.WriteString("Body:\n")
|
||||
attempt.bodyStarted = true
|
||||
}
|
||||
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
|
||||
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
|
||||
if attempt.bodyHasContent {
|
||||
attempt.response.WriteString("\n\n")
|
||||
separator := "\n\n"
|
||||
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
|
||||
separator = "\n"
|
||||
}
|
||||
attempt.response.WriteString(separator)
|
||||
}
|
||||
attempt.response.WriteString(string(data))
|
||||
attempt.bodyHasContent = true
|
||||
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
|
||||
|
||||
updateAggregatedResponse(ginCtx, attempts)
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
|
||||
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.request\n")
|
||||
if info.URL != "" {
|
||||
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
|
||||
}
|
||||
if auth := formatAuthInfo(info); auth != "" {
|
||||
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
|
||||
}
|
||||
builder.WriteString("Headers:\n")
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.Write(info.Body)
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
|
||||
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.handshake\n")
|
||||
if status > 0 {
|
||||
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||
}
|
||||
builder.WriteString("Headers:\n")
|
||||
writeHeaders(builder, headers)
|
||||
builder.WriteString("\n")
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
|
||||
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
RecordAPIRequest(ctx, cfg, info)
|
||||
RecordAPIResponseMetadata(ctx, cfg, status, headers)
|
||||
AppendAPIResponseChunk(ctx, cfg, body)
|
||||
}
|
||||
|
||||
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
|
||||
func WebsocketUpgradeRequestURL(rawURL string) string {
|
||||
trimmedURL := strings.TrimSpace(rawURL)
|
||||
if trimmedURL == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(trimmedURL)
|
||||
if err != nil {
|
||||
return trimmedURL
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "ws":
|
||||
parsed.Scheme = "http"
|
||||
case "wss":
|
||||
parsed.Scheme = "https"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
|
||||
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(payload)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
markAPIResponseTimestamp(ginCtx)
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.response\n")
|
||||
builder.Write(data)
|
||||
builder.WriteString("\n")
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
|
||||
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
|
||||
if cfg == nil || !cfg.RequestLog || err == nil {
|
||||
return
|
||||
}
|
||||
ginCtx := ginContextFrom(ctx)
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
markAPIResponseTimestamp(ginCtx)
|
||||
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||
builder.WriteString("Event: api.websocket.error\n")
|
||||
if trimmed := strings.TrimSpace(stage); trimmed != "" {
|
||||
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
|
||||
|
||||
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||
}
|
||||
|
||||
func ginContextFrom(ctx context.Context) *gin.Context {
|
||||
ginCtx, _ := ctx.Value("gin").(*gin.Context)
|
||||
return ginCtx
|
||||
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
|
||||
ginCtx.Set(apiResponseKey, []byte(builder.String()))
|
||||
}
|
||||
|
||||
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(chunk)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
combined := make([]byte, 0, len(existingBytes)+len(data)+2)
|
||||
combined = append(combined, existingBytes...)
|
||||
if !bytes.HasSuffix(existingBytes, []byte("\n")) {
|
||||
combined = append(combined, '\n')
|
||||
}
|
||||
combined = append(combined, '\n')
|
||||
combined = append(combined, data...)
|
||||
ginCtx.Set(apiWebsocketTimelineKey, combined)
|
||||
return
|
||||
}
|
||||
}
|
||||
ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data))
|
||||
}
|
||||
|
||||
func markAPIResponseTimestamp(ginCtx *gin.Context) {
|
||||
if ginCtx == nil {
|
||||
return
|
||||
}
|
||||
if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||
return
|
||||
}
|
||||
ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
|
||||
func writeHeaders(builder *strings.Builder, headers http.Header) {
|
||||
if builder == nil {
|
||||
return
|
||||
@@ -285,7 +464,7 @@ func writeHeaders(builder *strings.Builder, headers http.Header) {
|
||||
}
|
||||
}
|
||||
|
||||
func formatAuthInfo(info upstreamRequestLog) string {
|
||||
func formatAuthInfo(info UpstreamRequestLog) string {
|
||||
var parts []string
|
||||
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
|
||||
parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
|
||||
@@ -321,7 +500,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
func summarizeErrorBody(contentType string, body []byte) string {
|
||||
func SummarizeErrorBody(contentType string, body []byte) string {
|
||||
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
|
||||
if !isHTML {
|
||||
trimmed := bytes.TrimSpace(bytes.ToLower(body))
|
||||
@@ -379,7 +558,7 @@ func extractJSONErrorMessage(body []byte) string {
|
||||
|
||||
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||
// If no request ID is found in context, it returns the standard logger.
|
||||
func logWithRequestID(ctx context.Context) *log.Entry {
|
||||
func LogWithRequestID(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||
// against the original payload when provided. requestedModel carries the client-visible
|
||||
// model name before alias resolution so payload rules can target aliases precisely.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
@@ -244,7 +244,7 @@ func payloadRawValue(value any) ([]byte, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
fallback = strings.TrimSpace(fallback)
|
||||
if len(opts.Metadata) == 0 {
|
||||
return fallback
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -19,7 +19,7 @@ var (
|
||||
httpClientCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||
// NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||
// 3. Use RoundTripper from context if neither are configured
|
||||
@@ -34,7 +34,7 @@ var (
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client with configured proxy or transport
|
||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
// Priority 1: Use auth.ProxyURL if configured
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
@@ -46,23 +46,18 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
// Build cache key from proxy URL (empty string for no proxy)
|
||||
cacheKey := proxyURL
|
||||
|
||||
// Check cache first
|
||||
httpClientCacheMutex.RLock()
|
||||
if cachedClient, ok := httpClientCache[cacheKey]; ok {
|
||||
httpClientCacheMutex.RUnlock()
|
||||
// Return a wrapper with the requested timeout but shared transport
|
||||
if timeout > 0 {
|
||||
return &http.Client{
|
||||
Transport: cachedClient.Transport,
|
||||
Timeout: timeout,
|
||||
// If we have a proxy URL configured, try cache first to reuse TCP/TLS connections.
|
||||
if proxyURL != "" {
|
||||
httpClientCacheMutex.RLock()
|
||||
if cachedClient, ok := httpClientCache[proxyURL]; ok {
|
||||
httpClientCacheMutex.RUnlock()
|
||||
if timeout > 0 {
|
||||
return &http.Client{Transport: cachedClient.Transport, Timeout: timeout}
|
||||
}
|
||||
return cachedClient
|
||||
}
|
||||
return cachedClient
|
||||
httpClientCacheMutex.RUnlock()
|
||||
}
|
||||
httpClientCacheMutex.RUnlock()
|
||||
|
||||
// Create new client
|
||||
httpClient := &http.Client{}
|
||||
@@ -77,7 +72,7 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
httpClient.Transport = transport
|
||||
// Cache the client
|
||||
httpClientCacheMutex.Lock()
|
||||
httpClientCache[cacheKey] = httpClient
|
||||
httpClientCache[proxyURL] = httpClient
|
||||
httpClientCacheMutex.Unlock()
|
||||
return httpClient
|
||||
}
|
||||
@@ -90,13 +85,6 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
httpClient.Transport = rt
|
||||
}
|
||||
|
||||
// Cache the client for no-proxy case
|
||||
if proxyURL == "" {
|
||||
httpClientCacheMutex.Lock()
|
||||
httpClientCache[cacheKey] = httpClient
|
||||
httpClientCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
return httpClient
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := newProxyAwareHTTPClient(
|
||||
client := NewProxyAwareHTTPClient(
|
||||
context.Background(),
|
||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||
92
internal/runtime/executor/helps/session_id_cache.go
Normal file
92
internal/runtime/executor/helps/session_id_cache.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type sessionIDCacheEntry struct {
|
||||
value string
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
sessionIDCache = make(map[string]sessionIDCacheEntry)
|
||||
sessionIDCacheMu sync.RWMutex
|
||||
sessionIDCacheCleanupOnce sync.Once
|
||||
)
|
||||
|
||||
const (
|
||||
sessionIDTTL = time.Hour
|
||||
sessionIDCacheCleanupPeriod = 15 * time.Minute
|
||||
)
|
||||
|
||||
func startSessionIDCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(sessionIDCacheCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredSessionIDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredSessionIDs() {
|
||||
now := time.Now()
|
||||
sessionIDCacheMu.Lock()
|
||||
for key, entry := range sessionIDCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(sessionIDCache, key)
|
||||
}
|
||||
}
|
||||
sessionIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func sessionIDCacheKey(apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(apiKey))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// CachedSessionID returns a stable session UUID per apiKey, refreshing the TTL on each access.
|
||||
func CachedSessionID(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
sessionIDCacheCleanupOnce.Do(startSessionIDCacheCleanup)
|
||||
|
||||
key := sessionIDCacheKey(apiKey)
|
||||
now := time.Now()
|
||||
|
||||
sessionIDCacheMu.RLock()
|
||||
entry, ok := sessionIDCache[key]
|
||||
valid := ok && entry.value != "" && entry.expire.After(now)
|
||||
sessionIDCacheMu.RUnlock()
|
||||
if valid {
|
||||
sessionIDCacheMu.Lock()
|
||||
entry = sessionIDCache[key]
|
||||
if entry.value != "" && entry.expire.After(now) {
|
||||
entry.expire = now.Add(sessionIDTTL)
|
||||
sessionIDCache[key] = entry
|
||||
sessionIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
sessionIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
newID := uuid.New().String()
|
||||
|
||||
sessionIDCacheMu.Lock()
|
||||
entry, ok = sessionIDCache[key]
|
||||
if !ok || entry.value == "" || !entry.expire.After(now) {
|
||||
entry.value = newID
|
||||
}
|
||||
entry.expire = now.Add(sessionIDTTL)
|
||||
sessionIDCache[key] = entry
|
||||
sessionIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
|
||||
@@ -1,9 +1,7 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -11,100 +9,80 @@ import (
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation.
|
||||
var tokenizerCache sync.Map
|
||||
|
||||
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||
type TokenizerWrapper struct {
|
||||
Codec tokenizer.Codec
|
||||
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||
type adjustedTokenizer struct {
|
||||
tokenizer.Codec
|
||||
adjustmentFactor float64
|
||||
}
|
||||
|
||||
// Count returns the token count with adjustment factor applied
|
||||
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||
func (tw *adjustedTokenizer) Count(text string) (int, error) {
|
||||
count, err := tw.Codec.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||
if tw.adjustmentFactor > 0 && tw.adjustmentFactor != 1.0 {
|
||||
return int(float64(count) * tw.adjustmentFactor), nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// getTokenizer returns a cached tokenizer for the given model.
|
||||
// This improves performance by avoiding repeated tokenizer creation.
|
||||
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||
// Check cache first
|
||||
if cached, ok := tokenizerCache.Load(model); ok {
|
||||
return cached.(*TokenizerWrapper), nil
|
||||
// TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
// For Claude-like models, it applies an adjustment factor since tiktoken may underestimate token counts.
|
||||
func TokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
if cached, ok := tokenizerCache.Load(sanitized); ok {
|
||||
return cached.(tokenizer.Codec), nil
|
||||
}
|
||||
|
||||
// Cache miss, create new tokenizer
|
||||
wrapper, err := tokenizerForModel(model)
|
||||
enc, err := tokenizerForModel(sanitized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (use LoadOrStore to handle race conditions)
|
||||
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||
return actual.(*TokenizerWrapper), nil
|
||||
actual, _ := tokenizerCache.LoadOrStore(sanitized, enc)
|
||||
return actual.(tokenizer.Codec), nil
|
||||
}
|
||||
|
||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
func tokenizerForModel(sanitized string) (tokenizer.Codec, error) {
|
||||
if sanitized == "" {
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
}
|
||||
|
||||
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||
// because tiktoken may underestimate Claude's actual token count
|
||||
// Claude models use cl100k_base with an adjustment factor because tiktoken may underestimate.
|
||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||
return &adjustedTokenizer{Codec: enc, adjustmentFactor: 1.1}, nil
|
||||
}
|
||||
|
||||
var enc tokenizer.Codec
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case sanitized == "":
|
||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.2"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||
}
|
||||
|
||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
// CountOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
@@ -128,22 +106,15 @@ func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||
// This handles Claude's message format with system, messages, and tools.
|
||||
// Image tokens are estimated based on image dimensions when available.
|
||||
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
// CountClaudeChatTokens approximates prompt tokens for Claude API chat payloads.
|
||||
func CountClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
@@ -153,185 +124,25 @@ func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
imageTokens := 0
|
||||
|
||||
// Collect system prompt (can be string or array of content blocks)
|
||||
collectClaudeSystem(root.Get("system"), &segments)
|
||||
|
||||
// Collect messages
|
||||
collectClaudeMessages(root.Get("messages"), &segments)
|
||||
|
||||
// Collect tools
|
||||
collectClaudeContent(root.Get("system"), &segments, &imageTokens)
|
||||
collectClaudeMessages(root.Get("messages"), &segments, &imageTokens)
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
return int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
return int64(count + imageTokens), nil
|
||||
}
|
||||
|
||||
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
||||
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
||||
|
||||
// extractImageTokens extracts image token estimates from placeholder text.
|
||||
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
||||
func extractImageTokens(text string) int {
|
||||
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
||||
total := 0
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
||||
total += tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||
func estimateImageTokens(width, height float64) int {
|
||||
if width <= 0 || height <= 0 {
|
||||
// No valid dimensions, use default estimate (medium-sized image)
|
||||
return 1000
|
||||
}
|
||||
|
||||
tokens := int(width * height / 750)
|
||||
|
||||
// Apply bounds
|
||||
if tokens < 85 {
|
||||
tokens = 85
|
||||
}
|
||||
if tokens > 1590 {
|
||||
tokens = 1590
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// collectClaudeSystem extracts text from Claude's system field.
|
||||
// System can be a string or an array of content blocks.
|
||||
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||
if !system.Exists() {
|
||||
return
|
||||
}
|
||||
if system.Type == gjson.String {
|
||||
addIfNotEmpty(segments, system.String())
|
||||
return
|
||||
}
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "text" || blockType == "" {
|
||||
addIfNotEmpty(segments, block.Get("text").String())
|
||||
}
|
||||
// Also handle plain string blocks
|
||||
if block.Type == gjson.String {
|
||||
addIfNotEmpty(segments, block.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeMessages extracts text from Claude's messages array.
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// collectClaudeContent extracts text from Claude's content field.
|
||||
// Content can be a string or an array of content blocks.
|
||||
// For images, estimates token count based on dimensions when available.
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
// Estimate image tokens based on dimensions if available
|
||||
source := part.Get("source")
|
||||
if source.Exists() {
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if width > 0 && height > 0 {
|
||||
tokens := estimateImageTokens(width, height)
|
||||
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||
} else {
|
||||
// No dimensions available, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
} else {
|
||||
// No source info, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
// For unknown types, try to extract any text content
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeTools extracts text from Claude's tools array.
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
func buildOpenAIUsageJSON(count int64) []byte {
|
||||
// BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
func BuildOpenAIUsageJSON(count int64) []byte {
|
||||
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
|
||||
}
|
||||
|
||||
@@ -390,6 +201,10 @@ func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
}
|
||||
}
|
||||
|
||||
func CollectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
collectOpenAIContent(content, segments)
|
||||
}
|
||||
|
||||
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
|
||||
if !calls.Exists() || !calls.IsArray() {
|
||||
return
|
||||
@@ -487,6 +302,98 @@ func appendToolPayload(tool gjson.Result, segments *[]string) {
|
||||
}
|
||||
}
|
||||
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string, imageTokens *int) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments, imageTokens)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string, imageTokens *int) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
source := part.Get("source")
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if imageTokens != nil {
|
||||
*imageTokens += estimateImageTokens(width, height)
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments, imageTokens)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, content.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||
func estimateImageTokens(width, height float64) int {
|
||||
if width <= 0 || height <= 0 {
|
||||
// No valid dimensions, use default estimate (medium-sized image).
|
||||
return 1000
|
||||
}
|
||||
|
||||
tokens := int(width * height / 750)
|
||||
if tokens < 85 {
|
||||
return 85
|
||||
}
|
||||
if tokens > 1590 {
|
||||
return 1590
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func addIfNotEmpty(segments *[]string, value string) {
|
||||
if segments == nil {
|
||||
return
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type usageReporter struct {
|
||||
type UsageReporter struct {
|
||||
provider string
|
||||
model string
|
||||
authID string
|
||||
@@ -26,9 +26,9 @@ type usageReporter struct {
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
|
||||
apiKey := apiKeyFromContext(ctx)
|
||||
reporter := &usageReporter{
|
||||
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
|
||||
apiKey := APIKeyFromContext(ctx)
|
||||
reporter := &UsageReporter{
|
||||
provider: provider,
|
||||
model: model,
|
||||
requestedAt: time.Now(),
|
||||
@@ -42,24 +42,24 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
|
||||
return reporter
|
||||
}
|
||||
|
||||
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
|
||||
func (r *UsageReporter) Publish(ctx context.Context, detail usage.Detail) {
|
||||
r.publishWithOutcome(ctx, detail, false)
|
||||
}
|
||||
|
||||
func (r *usageReporter) publishFailure(ctx context.Context) {
|
||||
func (r *UsageReporter) PublishFailure(ctx context.Context) {
|
||||
r.publishWithOutcome(ctx, usage.Detail{}, true)
|
||||
}
|
||||
|
||||
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
|
||||
func (r *UsageReporter) TrackFailure(ctx context.Context, errPtr *error) {
|
||||
if r == nil || errPtr == nil {
|
||||
return
|
||||
}
|
||||
if *errPtr != nil {
|
||||
r.publishFailure(ctx)
|
||||
r.PublishFailure(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
|
||||
func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
||||
// It is safe to call multiple times; only the first call wins due to once.Do.
|
||||
// This is used to ensure request counting even when upstream responses do not
|
||||
// include any usage fields (tokens), especially for streaming paths.
|
||||
func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||
func (r *UsageReporter) EnsurePublished(ctx context.Context) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||
if r == nil {
|
||||
return usage.Record{Detail: detail, Failed: failed}
|
||||
}
|
||||
@@ -108,7 +108,7 @@ func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageReporter) latency() time.Duration {
|
||||
func (r *UsageReporter) latency() time.Duration {
|
||||
if r == nil || r.requestedAt.IsZero() {
|
||||
return 0
|
||||
}
|
||||
@@ -119,7 +119,7 @@ func (r *usageReporter) latency() time.Duration {
|
||||
return latency
|
||||
}
|
||||
|
||||
func apiKeyFromContext(ctx context.Context) string {
|
||||
func APIKeyFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseCodexUsage(data []byte) (usage.Detail, bool) {
|
||||
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
|
||||
usageNode := gjson.ParseBytes(data).Get("response.usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
@@ -203,7 +203,7 @@ func parseCodexUsage(data []byte) (usage.Detail, bool) {
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func parseOpenAIUsage(data []byte) usage.Detail {
|
||||
func ParseOpenAIUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}
|
||||
@@ -238,7 +238,7 @@ func parseOpenAIUsage(data []byte) usage.Detail {
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
@@ -247,59 +247,40 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
|
||||
inputNode := usageNode.Get("prompt_tokens")
|
||||
if !inputNode.Exists() {
|
||||
inputNode = usageNode.Get("input_tokens")
|
||||
}
|
||||
outputNode := usageNode.Get("completion_tokens")
|
||||
if !outputNode.Exists() {
|
||||
outputNode = usageNode.Get("output_tokens")
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("prompt_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("completion_tokens").Int(),
|
||||
InputTokens: inputNode.Int(),
|
||||
OutputTokens: outputNode.Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
|
||||
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
|
||||
if !cached.Exists() {
|
||||
cached = usageNode.Get("input_tokens_details.cached_tokens")
|
||||
}
|
||||
if cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
||||
|
||||
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
|
||||
if !reasoning.Exists() {
|
||||
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
|
||||
}
|
||||
if reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("input_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("output_tokens").Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
|
||||
}
|
||||
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
return parseOpenAIResponsesUsageDetail(usageNode)
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
usageNode := gjson.GetBytes(payload, "usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
return parseOpenAIResponsesUsageDetail(usageNode), true
|
||||
}
|
||||
|
||||
func parseClaudeUsage(data []byte) usage.Detail {
|
||||
func ParseClaudeUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}
|
||||
@@ -317,7 +298,7 @@ func parseClaudeUsage(data []byte) usage.Detail {
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
func ParseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
@@ -352,7 +333,7 @@ func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
func ParseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("response.usageMetadata")
|
||||
if !node.Exists() {
|
||||
@@ -364,7 +345,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseGeminiUsage(data []byte) usage.Detail {
|
||||
func ParseGeminiUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("usageMetadata")
|
||||
if !node.Exists() {
|
||||
@@ -376,7 +357,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
@@ -391,7 +372,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
@@ -406,7 +387,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func parseAntigravityUsage(data []byte) usage.Detail {
|
||||
func ParseAntigravityUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("response.usageMetadata")
|
||||
if !node.Exists() {
|
||||
@@ -421,7 +402,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
func ParseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
@@ -590,6 +571,10 @@ func isStopChunkWithoutUsage(jsonBytes []byte) bool {
|
||||
return !hasUsageMetadata(jsonBytes)
|
||||
}
|
||||
|
||||
func JSONPayload(line []byte) []byte {
|
||||
return jsonPayload(line)
|
||||
}
|
||||
|
||||
func jsonPayload(line []byte) []byte {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 {
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||
detail := parseOpenAIUsage(data)
|
||||
detail := ParseOpenAIUsage(data)
|
||||
if detail.InputTokens != 1 {
|
||||
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||
|
||||
func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
|
||||
detail := parseOpenAIUsage(data)
|
||||
detail := ParseOpenAIUsage(data)
|
||||
if detail.InputTokens != 10 {
|
||||
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||
reporter := &usageReporter{
|
||||
reporter := &UsageReporter{
|
||||
provider: "openai",
|
||||
model: "gpt-5.4",
|
||||
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@@ -49,7 +49,7 @@ func userIDCacheKey(apiKey string) string {
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func cachedUserID(apiKey string) string {
|
||||
func CachedUserID(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return generateFakeUserID()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package executor
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -14,8 +14,8 @@ func resetUserIDCache() {
|
||||
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := cachedUserID("api-key-1")
|
||||
second := cachedUserID("api-key-1")
|
||||
first := CachedUserID("api-key-1")
|
||||
second := CachedUserID("api-key-1")
|
||||
|
||||
if first == "" {
|
||||
t.Fatal("expected generated user_id to be non-empty")
|
||||
@@ -28,7 +28,7 @@ func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
expiredID := cachedUserID("api-key-expired")
|
||||
expiredID := CachedUserID("api-key-expired")
|
||||
cacheKey := userIDCacheKey("api-key-expired")
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
@@ -37,7 +37,7 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
newID := cachedUserID("api-key-expired")
|
||||
newID := CachedUserID("api-key-expired")
|
||||
if newID == expiredID {
|
||||
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||
}
|
||||
@@ -49,8 +49,8 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := cachedUserID("api-key-1")
|
||||
second := cachedUserID("api-key-2")
|
||||
first := CachedUserID("api-key-1")
|
||||
second := CachedUserID("api-key-2")
|
||||
|
||||
if first == second {
|
||||
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||
@@ -61,7 +61,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
key := "api-key-renew"
|
||||
id := cachedUserID(key)
|
||||
id := CachedUserID(key)
|
||||
cacheKey := userIDCacheKey(key)
|
||||
|
||||
soon := time.Now()
|
||||
@@ -72,7 +72,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
if refreshed := cachedUserID(key); refreshed != id {
|
||||
if refreshed := CachedUserID(key); refreshed != id {
|
||||
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||
}
|
||||
|
||||
188
internal/runtime/executor/helps/utls_client.go
Normal file
188
internal/runtime/executor/helps/utls_client.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
type utlsRoundTripper struct {
|
||||
mu sync.Mutex
|
||||
connections map[string]*http2.ClientConn
|
||||
pending map[string]*sync.Cond
|
||||
dialer proxy.Dialer
|
||||
}
|
||||
|
||||
func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
|
||||
var dialer proxy.Dialer = proxy.Direct
|
||||
if proxyURL != "" {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
}
|
||||
return &utlsRoundTripper{
|
||||
connections: make(map[string]*http2.ClientConn),
|
||||
pending: make(map[string]*sync.Cond),
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
t.mu.Lock()
|
||||
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
if cond, ok := t.pending[host]; ok {
|
||||
cond.Wait()
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
cond := sync.NewCond(&t.mu)
|
||||
t.pending[host] = cond
|
||||
t.mu.Unlock()
|
||||
|
||||
h2Conn, err := t.createConnection(host, addr)
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
delete(t.pending, host)
|
||||
cond.Broadcast()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.connections[host] = h2Conn
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := t.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{}
|
||||
h2Conn, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
hostname := req.URL.Hostname()
|
||||
port := req.URL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
addr := net.JoinHostPort(hostname, port)
|
||||
|
||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := h2Conn.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.mu.Lock()
|
||||
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
|
||||
delete(t.connections, hostname)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// anthropicHosts contains the hosts that should use utls Chrome TLS fingerprint.
|
||||
var anthropicHosts = map[string]struct{}{
|
||||
"api.anthropic.com": {},
|
||||
}
|
||||
|
||||
// fallbackRoundTripper uses utls for Anthropic HTTPS hosts and falls back to
|
||||
// standard transport for all other requests (non-HTTPS or non-Anthropic hosts).
|
||||
type fallbackRoundTripper struct {
|
||||
utls *utlsRoundTripper
|
||||
fallback http.RoundTripper
|
||||
}
|
||||
|
||||
func (f *fallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Scheme == "https" {
|
||||
if _, ok := anthropicHosts[strings.ToLower(req.URL.Hostname())]; ok {
|
||||
return f.utls.RoundTrip(req)
|
||||
}
|
||||
}
|
||||
return f.fallback.RoundTrip(req)
|
||||
}
|
||||
|
||||
// NewUtlsHTTPClient creates an HTTP client using utls Chrome TLS fingerprint.
|
||||
// Use this for Claude API requests to match real Claude Code's TLS behavior.
|
||||
// Falls back to standard transport for non-HTTPS requests.
|
||||
func NewUtlsHTTPClient(cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
proxyURL = strings.TrimSpace(auth.ProxyURL)
|
||||
}
|
||||
if proxyURL == "" && cfg != nil {
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
utlsRT := newUtlsRoundTripper(proxyURL)
|
||||
|
||||
var standardTransport http.RoundTripper = &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
}
|
||||
if proxyURL != "" {
|
||||
if transport := buildProxyTransport(proxyURL); transport != nil {
|
||||
standardTransport = transport
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &fallbackRoundTripper{
|
||||
utls: utlsRT,
|
||||
fallback: standardTransport,
|
||||
},
|
||||
}
|
||||
if timeout > 0 {
|
||||
client.Timeout = timeout
|
||||
}
|
||||
return client
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -66,7 +67,7 @@ func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -86,8 +87,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
baseURL = iflowauth.DefaultAPIBaseURL
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
@@ -106,8 +107,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -116,13 +117,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
return resp, err
|
||||
}
|
||||
applyIFlowHeaders(httpReq, apiKey, false)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -134,10 +140,10 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -145,25 +151,25 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||
// Ensure usage is recorded even if upstream omits usage metadata.
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
@@ -189,8 +195,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
baseURL = iflowauth.DefaultAPIBaseURL
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
@@ -214,8 +220,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
body = ensureToolsArray(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -224,13 +230,18 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
applyIFlowHeaders(httpReq, apiKey, true)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -242,21 +253,21 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, _ := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
return nil, err
|
||||
}
|
||||
@@ -275,9 +286,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
@@ -285,12 +296,12 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
// Guarantee a usage record exists even if the stream never emitted usage data.
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
@@ -303,17 +314,17 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
enc, err := tokenizerForModel(baseModel)
|
||||
enc, err := helps.TokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, body)
|
||||
count, err := helps.CountOpenAIChatTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,9 @@ import (
|
||||
|
||||
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -45,6 +47,11 @@ func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +67,7 @@ func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -76,8 +83,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
@@ -100,8 +107,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
@@ -113,13 +120,18 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -131,10 +143,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -142,21 +154,21 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
@@ -176,8 +188,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
@@ -204,8 +216,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -217,13 +229,18 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -235,17 +252,17 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -265,9 +282,9 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
@@ -279,8 +296,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -65,15 +66,15 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
baseURL, apiKey := e.resolveCredentials(auth)
|
||||
if baseURL == "" {
|
||||
@@ -95,8 +96,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
if opts.Alt == "responses/compact" {
|
||||
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
|
||||
translated = updated
|
||||
@@ -129,7 +130,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -141,10 +142,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -152,23 +153,23 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
|
||||
// Ensure we at least record the request even if upstream doesn't return usage
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
@@ -179,8 +180,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
baseURL, apiKey := e.resolveCredentials(auth)
|
||||
if baseURL == "" {
|
||||
@@ -197,8 +198,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -232,7 +233,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -244,17 +245,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -274,9 +275,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
@@ -294,12 +295,12 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
reporter.ensurePublished(ctx)
|
||||
reporter.EnsurePublished(ctx)
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
@@ -318,17 +319,17 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
enc, err := tokenizerForModel(modelForCounting)
|
||||
enc, err := helps.TokenizerForModel(modelForCounting)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, translated)
|
||||
count, err := helps.CountOpenAIChatTokens(enc, translated)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -23,7 +25,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
qwenUserAgent = "QwenCode/0.13.2 (darwin; arm64)"
|
||||
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||
)
|
||||
@@ -154,7 +156,7 @@ func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int,
|
||||
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||
cooldown := timeUntilNextDay()
|
||||
retryAfter = &cooldown
|
||||
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||
helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||
}
|
||||
return errCode, retryAfter
|
||||
}
|
||||
@@ -202,7 +204,7 @@ func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -217,7 +219,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -228,8 +230,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
baseURL = "https://portal.qwen.ai/v1"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
@@ -247,8 +249,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -256,12 +258,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, false)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -273,10 +280,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
@@ -284,23 +291,23 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
@@ -320,7 +327,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
authID = auth.ID
|
||||
}
|
||||
if err := checkQwenRateLimit(authID); err != nil {
|
||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -331,8 +338,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
baseURL = "https://portal.qwen.ai/v1"
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
@@ -357,8 +364,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -366,12 +373,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
applyQwenHeaders(httpReq, token, true)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
@@ -383,19 +395,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||
|
||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -415,9 +427,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
@@ -429,8 +441,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
@@ -449,17 +461,17 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
modelName = baseModel
|
||||
}
|
||||
|
||||
enc, err := tokenizerForModel(modelName)
|
||||
enc, err := helps.TokenizerForModel(modelName)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, body)
|
||||
count, err := helps.CountOpenAIChatTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
@@ -508,16 +520,15 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
r.Header.Set("User-Agent", qwenUserAgent)
|
||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||
r.Header["X-DashScope-UserAgent"] = []string{qwenUserAgent}
|
||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
r.Header.Set("X-Stainless-Lang", "js")
|
||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
r.Header["X-DashScope-CacheControl"] = []string{"enable"}
|
||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
r.Header["X-DashScope-AuthType"] = []string{"qwen-oauth"}
|
||||
r.Header.Set("X-Stainless-Runtime", "node")
|
||||
|
||||
if stream {
|
||||
|
||||
@@ -446,6 +446,7 @@ func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||
auth.Attributes["email"] = email
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -595,6 +595,7 @@ func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Aut
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -310,6 +310,7 @@ func (s *PostgresStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error)
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
}
|
||||
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||
auths = append(auths, auth)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
|
||||
@@ -330,32 +330,45 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder parts for 'model' role to ensure thinking block is first
|
||||
// Reorder parts for 'model' role:
|
||||
// 1. Thinking parts first (Antigravity API requirement)
|
||||
// 2. Regular parts (text, inlineData, etc.)
|
||||
// 3. FunctionCall parts last
|
||||
//
|
||||
// Moving functionCall parts to the end prevents tool_use↔tool_result
|
||||
// pairing breakage: the Antigravity API internally splits model messages
|
||||
// at functionCall boundaries. If a text part follows a functionCall, the
|
||||
// split creates an extra assistant turn between tool_use and tool_result,
|
||||
// which Claude rejects with "tool_use ids were found without tool_result
|
||||
// blocks immediately after".
|
||||
if role == "model" {
|
||||
partsResult := gjson.GetBytes(clientContentJSON, "parts")
|
||||
if partsResult.IsArray() {
|
||||
parts := partsResult.Array()
|
||||
var thinkingParts []gjson.Result
|
||||
var otherParts []gjson.Result
|
||||
for _, part := range parts {
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingParts = append(thinkingParts, part)
|
||||
} else {
|
||||
otherParts = append(otherParts, part)
|
||||
}
|
||||
}
|
||||
if len(thinkingParts) > 0 {
|
||||
firstPartIsThinking := parts[0].Get("thought").Bool()
|
||||
if !firstPartIsThinking || len(thinkingParts) > 1 {
|
||||
var newParts []interface{}
|
||||
for _, p := range thinkingParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
if len(parts) > 1 {
|
||||
var thinkingParts []gjson.Result
|
||||
var regularParts []gjson.Result
|
||||
var functionCallParts []gjson.Result
|
||||
for _, part := range parts {
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingParts = append(thinkingParts, part)
|
||||
} else if part.Get("functionCall").Exists() {
|
||||
functionCallParts = append(functionCallParts, part)
|
||||
} else {
|
||||
regularParts = append(regularParts, part)
|
||||
}
|
||||
for _, p := range otherParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
||||
}
|
||||
var newParts []interface{}
|
||||
for _, p := range thinkingParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
for _, p := range regularParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
for _, p := range functionCallParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -361,6 +361,167 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderTextAfterFunctionCall(t *testing.T) {
|
||||
// Bug: text part after tool_use in an assistant message causes Antigravity
|
||||
// to split at functionCall boundary, creating an extra assistant turn that
|
||||
// breaks tool_use↔tool_result adjacency (upstream issue #989).
|
||||
// Fix: reorder parts so functionCall comes last.
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me check..."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_abc",
|
||||
"name": "Read",
|
||||
"input": {"file": "test.go"}
|
||||
},
|
||||
{"type": "text", "text": "Reading the file now"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_abc",
|
||||
"content": "file content"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("Expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Text parts should come before functionCall
|
||||
if parts[0].Get("text").String() != "Let me check..." {
|
||||
t.Errorf("Expected first text part first, got %s", parts[0].Raw)
|
||||
}
|
||||
if parts[1].Get("text").String() != "Reading the file now" {
|
||||
t.Errorf("Expected second text part second, got %s", parts[1].Raw)
|
||||
}
|
||||
if !parts[2].Get("functionCall").Exists() {
|
||||
t.Errorf("Expected functionCall last, got %s", parts[2].Raw)
|
||||
}
|
||||
if parts[2].Get("functionCall.name").String() != "Read" {
|
||||
t.Errorf("Expected functionCall name 'Read', got '%s'", parts[2].Get("functionCall.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderParallelFunctionCalls(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Reading both files."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_1",
|
||||
"name": "Read",
|
||||
"input": {"file": "a.go"}
|
||||
},
|
||||
{"type": "text", "text": "And this one too."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_2",
|
||||
"name": "Read",
|
||||
"input": {"file": "b.go"}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 4 {
|
||||
t.Fatalf("Expected 4 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
if parts[0].Get("text").String() != "Reading both files." {
|
||||
t.Errorf("Expected first text, got %s", parts[0].Raw)
|
||||
}
|
||||
if parts[1].Get("text").String() != "And this one too." {
|
||||
t.Errorf("Expected second text, got %s", parts[1].Raw)
|
||||
}
|
||||
if parts[2].Get("functionCall.name").String() != "Read" || parts[2].Get("functionCall.id").String() != "call_1" {
|
||||
t.Errorf("Expected fc1 third, got %s", parts[2].Raw)
|
||||
}
|
||||
if parts[3].Get("functionCall.name").String() != "Read" || parts[3].Get("functionCall.id").String() != "call_2" {
|
||||
t.Errorf("Expected fc2 fourth, got %s", parts[3].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderThinkingAndTextBeforeFunctionCall(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think about this..."
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Before thinking"},
|
||||
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_xyz",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"}
|
||||
},
|
||||
{"type": "text", "text": "After tool call"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// contents.1 = assistant message (contents.0 = user)
|
||||
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
|
||||
if len(parts) != 4 {
|
||||
t.Fatalf("Expected 4 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Order: thinking → text → text → functionCall
|
||||
if !parts[0].Get("thought").Bool() {
|
||||
t.Error("First part should be thinking")
|
||||
}
|
||||
if parts[1].Get("functionCall").Exists() || parts[1].Get("thought").Bool() {
|
||||
t.Errorf("Second part should be text, got %s", parts[1].Raw)
|
||||
}
|
||||
if parts[2].Get("functionCall").Exists() || parts[2].Get("thought").Bool() {
|
||||
t.Errorf("Third part should be text, got %s", parts[2].Raw)
|
||||
}
|
||||
if !parts[3].Get("functionCall").Exists() {
|
||||
t.Errorf("Last part should be functionCall, got %s", parts[3].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
|
||||
@@ -284,12 +284,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
}
|
||||
|
||||
// Process the output array for content and function calls
|
||||
var toolCalls [][]byte
|
||||
outputResult := responseResult.Get("output")
|
||||
if outputResult.IsArray() {
|
||||
outputArray := outputResult.Array()
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls [][]byte
|
||||
|
||||
for _, outputItem := range outputArray {
|
||||
outputType := outputItem.Get("type").String()
|
||||
@@ -367,8 +367,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if statusResult := responseResult.Get("status"); statusResult.Exists() {
|
||||
status := statusResult.String()
|
||||
if status == "completed" {
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
@@ -31,8 +31,6 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// - []byte: The transformed request in Gemini CLI format.
|
||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := []byte(`{"contents":[]}`)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
@@ -146,13 +144,37 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
})
|
||||
}
|
||||
|
||||
// strip trailing model turn with unanswered function calls —
|
||||
// Gemini returns empty responses when the last turn is a model
|
||||
// functionCall with no corresponding user functionResponse.
|
||||
contents := gjson.GetBytes(out, "contents")
|
||||
if contents.Exists() && contents.IsArray() {
|
||||
arr := contents.Array()
|
||||
if len(arr) > 0 {
|
||||
last := arr[len(arr)-1]
|
||||
if last.Get("role").String() == "model" {
|
||||
hasFC := false
|
||||
last.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
hasFC = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if hasFC {
|
||||
out, _ = sjson.DeleteBytes(out, fmt.Sprintf("contents.%d", len(arr)-1))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
|
||||
hasTools := false
|
||||
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||
tool := []byte(toolResult.Raw)
|
||||
var err error
|
||||
tool, err = sjson.DeleteBytes(tool, "input_schema")
|
||||
@@ -168,6 +190,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
||||
if !hasTools {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
type oaiToResponsesStateReasoning struct {
|
||||
ReasoningID string
|
||||
ReasoningData string
|
||||
OutputIndex int
|
||||
}
|
||||
type oaiToResponsesState struct {
|
||||
Seq int
|
||||
@@ -29,16 +31,19 @@ type oaiToResponsesState struct {
|
||||
MsgTextBuf map[int]*strings.Builder
|
||||
ReasoningBuf strings.Builder
|
||||
Reasonings []oaiToResponsesStateReasoning
|
||||
FuncArgsBuf map[int]*strings.Builder // index -> args
|
||||
FuncNames map[int]string // index -> name
|
||||
FuncCallIDs map[int]string // index -> call_id
|
||||
FuncArgsBuf map[string]*strings.Builder
|
||||
FuncNames map[string]string
|
||||
FuncCallIDs map[string]string
|
||||
FuncOutputIx map[string]int
|
||||
MsgOutputIx map[int]int
|
||||
NextOutputIx int
|
||||
// message item state per output index
|
||||
MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
|
||||
MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
|
||||
MsgItemDone map[int]bool // whether message done events were emitted
|
||||
// function item done state
|
||||
FuncArgsDone map[int]bool
|
||||
FuncItemDone map[int]bool
|
||||
FuncArgsDone map[string]bool
|
||||
FuncItemDone map[string]bool
|
||||
// usage aggregation
|
||||
PromptTokens int64
|
||||
CachedTokens int64
|
||||
@@ -60,15 +65,17 @@ func emitRespEvent(event string, payload []byte) []byte {
|
||||
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &oaiToResponsesState{
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncArgsBuf: make(map[string]*strings.Builder),
|
||||
FuncNames: make(map[string]string),
|
||||
FuncCallIDs: make(map[string]string),
|
||||
FuncOutputIx: make(map[string]int),
|
||||
MsgOutputIx: make(map[int]int),
|
||||
MsgTextBuf: make(map[int]*strings.Builder),
|
||||
MsgItemAdded: make(map[int]bool),
|
||||
MsgContentAdded: make(map[int]bool),
|
||||
MsgItemDone: make(map[int]bool),
|
||||
FuncArgsDone: make(map[int]bool),
|
||||
FuncItemDone: make(map[int]bool),
|
||||
FuncArgsDone: make(map[string]bool),
|
||||
FuncItemDone: make(map[string]bool),
|
||||
Reasonings: make([]oaiToResponsesStateReasoning, 0),
|
||||
}
|
||||
}
|
||||
@@ -125,6 +132,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
|
||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||
allocOutputIndex := func() int {
|
||||
ix := st.NextOutputIx
|
||||
st.NextOutputIx++
|
||||
return ix
|
||||
}
|
||||
toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
|
||||
var out [][]byte
|
||||
|
||||
if !st.Started {
|
||||
@@ -135,14 +148,17 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
st.ReasoningBuf.Reset()
|
||||
st.ReasoningID = ""
|
||||
st.ReasoningIndex = 0
|
||||
st.FuncArgsBuf = make(map[int]*strings.Builder)
|
||||
st.FuncNames = make(map[int]string)
|
||||
st.FuncCallIDs = make(map[int]string)
|
||||
st.FuncArgsBuf = make(map[string]*strings.Builder)
|
||||
st.FuncNames = make(map[string]string)
|
||||
st.FuncCallIDs = make(map[string]string)
|
||||
st.FuncOutputIx = make(map[string]int)
|
||||
st.MsgOutputIx = make(map[int]int)
|
||||
st.NextOutputIx = 0
|
||||
st.MsgItemAdded = make(map[int]bool)
|
||||
st.MsgContentAdded = make(map[int]bool)
|
||||
st.MsgItemDone = make(map[int]bool)
|
||||
st.FuncArgsDone = make(map[int]bool)
|
||||
st.FuncItemDone = make(map[int]bool)
|
||||
st.FuncArgsDone = make(map[string]bool)
|
||||
st.FuncItemDone = make(map[string]bool)
|
||||
st.PromptTokens = 0
|
||||
st.CachedTokens = 0
|
||||
st.CompletionTokens = 0
|
||||
@@ -185,7 +201,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
|
||||
out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
|
||||
|
||||
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
|
||||
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
|
||||
st.ReasoningID = ""
|
||||
}
|
||||
|
||||
@@ -201,10 +217,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
stopReasoning(st.ReasoningBuf.String())
|
||||
st.ReasoningBuf.Reset()
|
||||
}
|
||||
if _, exists := st.MsgOutputIx[idx]; !exists {
|
||||
st.MsgOutputIx[idx] = allocOutputIndex()
|
||||
}
|
||||
msgOutputIndex := st.MsgOutputIx[idx]
|
||||
if !st.MsgItemAdded[idx] {
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
|
||||
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
out = append(out, emitRespEvent("response.output_item.added", item))
|
||||
st.MsgItemAdded[idx] = true
|
||||
@@ -213,7 +233,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
part, _ = sjson.SetBytes(part, "output_index", idx)
|
||||
part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
|
||||
part, _ = sjson.SetBytes(part, "content_index", 0)
|
||||
out = append(out, emitRespEvent("response.content_part.added", part))
|
||||
st.MsgContentAdded[idx] = true
|
||||
@@ -222,7 +242,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", idx)
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
|
||||
msg, _ = sjson.SetBytes(msg, "content_index", 0)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", c.String())
|
||||
out = append(out, emitRespEvent("response.output_text.delta", msg))
|
||||
@@ -238,10 +258,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
// On first appearance, add reasoning item and part
|
||||
if st.ReasoningID == "" {
|
||||
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
|
||||
st.ReasoningIndex = idx
|
||||
st.ReasoningIndex = allocOutputIndex()
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
|
||||
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
|
||||
out = append(out, emitRespEvent("response.output_item.added", item))
|
||||
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||
@@ -269,6 +289,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
// Before emitting any function events, if a message is open for this index,
|
||||
// close its text/content to match Codex expected ordering.
|
||||
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
|
||||
msgOutputIndex := st.MsgOutputIx[idx]
|
||||
fullText := ""
|
||||
if b := st.MsgTextBuf[idx]; b != nil {
|
||||
fullText = b.String()
|
||||
@@ -276,7 +297,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
done, _ = sjson.SetBytes(done, "output_index", idx)
|
||||
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||
@@ -284,69 +305,72 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", idx)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
st.MsgItemDone[idx] = true
|
||||
}
|
||||
|
||||
// Only emit item.added once per tool call and preserve call_id across chunks.
|
||||
newCallID := tcs.Get("0.id").String()
|
||||
nameChunk := tcs.Get("0.function.name").String()
|
||||
if nameChunk != "" {
|
||||
st.FuncNames[idx] = nameChunk
|
||||
}
|
||||
existingCallID := st.FuncCallIDs[idx]
|
||||
effectiveCallID := existingCallID
|
||||
shouldEmitItem := false
|
||||
if existingCallID == "" && newCallID != "" {
|
||||
// First time seeing a valid call_id for this index
|
||||
effectiveCallID = newCallID
|
||||
st.FuncCallIDs[idx] = newCallID
|
||||
shouldEmitItem = true
|
||||
}
|
||||
|
||||
if shouldEmitItem && effectiveCallID != "" {
|
||||
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
|
||||
o, _ = sjson.SetBytes(o, "output_index", idx)
|
||||
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
|
||||
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
|
||||
name := st.FuncNames[idx]
|
||||
o, _ = sjson.SetBytes(o, "item.name", name)
|
||||
out = append(out, emitRespEvent("response.output_item.added", o))
|
||||
}
|
||||
|
||||
// Ensure args buffer exists for this index
|
||||
if st.FuncArgsBuf[idx] == nil {
|
||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||
}
|
||||
|
||||
// Append arguments delta if available and we have a valid call_id to reference
|
||||
if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
|
||||
// Prefer an already known call_id; fall back to newCallID if first time
|
||||
refCallID := st.FuncCallIDs[idx]
|
||||
if refCallID == "" {
|
||||
refCallID = newCallID
|
||||
tcs.ForEach(func(_, tc gjson.Result) bool {
|
||||
toolIndex := int(tc.Get("index").Int())
|
||||
key := toolStateKey(idx, toolIndex)
|
||||
newCallID := tc.Get("id").String()
|
||||
nameChunk := tc.Get("function.name").String()
|
||||
if nameChunk != "" {
|
||||
st.FuncNames[key] = nameChunk
|
||||
}
|
||||
if refCallID != "" {
|
||||
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
|
||||
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
|
||||
ad, _ = sjson.SetBytes(ad, "output_index", idx)
|
||||
ad, _ = sjson.SetBytes(ad, "delta", args.String())
|
||||
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
|
||||
|
||||
existingCallID := st.FuncCallIDs[key]
|
||||
effectiveCallID := existingCallID
|
||||
shouldEmitItem := false
|
||||
if existingCallID == "" && newCallID != "" {
|
||||
effectiveCallID = newCallID
|
||||
st.FuncCallIDs[key] = newCallID
|
||||
st.FuncOutputIx[key] = allocOutputIndex()
|
||||
shouldEmitItem = true
|
||||
}
|
||||
st.FuncArgsBuf[idx].WriteString(args.String())
|
||||
}
|
||||
|
||||
if shouldEmitItem && effectiveCallID != "" {
|
||||
outputIndex := st.FuncOutputIx[key]
|
||||
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
|
||||
o, _ = sjson.SetBytes(o, "output_index", outputIndex)
|
||||
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
|
||||
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
|
||||
o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
|
||||
out = append(out, emitRespEvent("response.output_item.added", o))
|
||||
}
|
||||
|
||||
if st.FuncArgsBuf[key] == nil {
|
||||
st.FuncArgsBuf[key] = &strings.Builder{}
|
||||
}
|
||||
|
||||
if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
|
||||
refCallID := st.FuncCallIDs[key]
|
||||
if refCallID == "" {
|
||||
refCallID = newCallID
|
||||
}
|
||||
if refCallID != "" {
|
||||
outputIndex := st.FuncOutputIx[key]
|
||||
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
|
||||
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
|
||||
ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
|
||||
ad, _ = sjson.SetBytes(ad, "delta", args.String())
|
||||
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
|
||||
}
|
||||
st.FuncArgsBuf[key].WriteString(args.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,15 +384,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
for i := range st.MsgItemAdded {
|
||||
idxs = append(idxs, i)
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
|
||||
for _, i := range idxs {
|
||||
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
|
||||
msgOutputIndex := st.MsgOutputIx[i]
|
||||
fullText := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
fullText = b.String()
|
||||
@@ -376,7 +395,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
done, _ = sjson.SetBytes(done, "output_index", i)
|
||||
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||
@@ -384,14 +403,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", i)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
@@ -407,43 +426,42 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
|
||||
// Emit function call done events for any active function calls
|
||||
if len(st.FuncCallIDs) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncCallIDs))
|
||||
for i := range st.FuncCallIDs {
|
||||
idxs = append(idxs, i)
|
||||
keys := make([]string, 0, len(st.FuncCallIDs))
|
||||
for key := range st.FuncCallIDs {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range idxs {
|
||||
callID := st.FuncCallIDs[i]
|
||||
if callID == "" || st.FuncItemDone[i] {
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
left := st.FuncOutputIx[keys[i]]
|
||||
right := st.FuncOutputIx[keys[j]]
|
||||
return left < right || (left == right && keys[i] < keys[j])
|
||||
})
|
||||
for _, key := range keys {
|
||||
callID := st.FuncCallIDs[key]
|
||||
if callID == "" || st.FuncItemDone[key] {
|
||||
continue
|
||||
}
|
||||
outputIndex := st.FuncOutputIx[key]
|
||||
args := "{}"
|
||||
if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
|
||||
if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
|
||||
args = b.String()
|
||||
}
|
||||
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", i)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
|
||||
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i])
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
st.FuncItemDone[i] = true
|
||||
st.FuncArgsDone[i] = true
|
||||
st.FuncItemDone[key] = true
|
||||
st.FuncArgsDone[key] = true
|
||||
}
|
||||
}
|
||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||
@@ -516,28 +534,21 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
// Build response.output using aggregated buffers
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
type completedOutputItem struct {
|
||||
index int
|
||||
raw []byte
|
||||
}
|
||||
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
|
||||
if len(st.Reasonings) > 0 {
|
||||
for _, r := range st.Reasonings {
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
|
||||
}
|
||||
}
|
||||
// Append message items in ascending index order
|
||||
if len(st.MsgItemAdded) > 0 {
|
||||
midxs := make([]int, 0, len(st.MsgItemAdded))
|
||||
for i := range st.MsgItemAdded {
|
||||
midxs = append(midxs, i)
|
||||
}
|
||||
for i := 0; i < len(midxs); i++ {
|
||||
for j := i + 1; j < len(midxs); j++ {
|
||||
if midxs[j] < midxs[i] {
|
||||
midxs[i], midxs[j] = midxs[j], midxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range midxs {
|
||||
txt := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
txt = b.String()
|
||||
@@ -545,37 +556,29 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", txt)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
|
||||
}
|
||||
}
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncArgsBuf))
|
||||
for i := range st.FuncArgsBuf {
|
||||
idxs = append(idxs, i)
|
||||
}
|
||||
// small-N sort without extra imports
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range idxs {
|
||||
for key := range st.FuncArgsBuf {
|
||||
args := ""
|
||||
if b := st.FuncArgsBuf[i]; b != nil {
|
||||
if b := st.FuncArgsBuf[key]; b != nil {
|
||||
args = b.String()
|
||||
}
|
||||
callID := st.FuncCallIDs[i]
|
||||
name := st.FuncNames[i]
|
||||
callID := st.FuncCallIDs[key]
|
||||
name := st.FuncNames[key]
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
|
||||
}
|
||||
}
|
||||
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
|
||||
for _, item := range outputItems {
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
|
||||
}
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
|
||||
t.Helper()
|
||||
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
if len(lines) < 2 {
|
||||
t.Fatalf("unexpected SSE chunk: %q", chunk)
|
||||
}
|
||||
|
||||
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||
if !gjson.Valid(dataLine) {
|
||||
t.Fatalf("invalid SSE data JSON: %q", dataLine)
|
||||
}
|
||||
return event, gjson.Parse(dataLine)
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
addedNames := map[string]string{}
|
||||
doneArgs := map[string]string{}
|
||||
doneNames := map[string]string{}
|
||||
outputItems := map[string]gjson.Result{}
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
addedNames[data.Get("item.call_id").String()] = data.Get("item.name").String()
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := data.Get("item.call_id").String()
|
||||
doneArgs[callID] = data.Get("item.arguments").String()
|
||||
doneNames[callID] = data.Get("item.name").String()
|
||||
case "response.completed":
|
||||
output := data.Get("response.output")
|
||||
for _, item := range output.Array() {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
outputItems[item.Get("call_id").String()] = item
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(addedNames) != 2 {
|
||||
t.Fatalf("expected 2 function_call added events, got %d", len(addedNames))
|
||||
}
|
||||
if len(doneArgs) != 2 {
|
||||
t.Fatalf("expected 2 function_call done events, got %d", len(doneArgs))
|
||||
}
|
||||
|
||||
if addedNames["call_read"] != "read" {
|
||||
t.Fatalf("unexpected added name for call_read: %q", addedNames["call_read"])
|
||||
}
|
||||
if addedNames["call_glob"] != "glob" {
|
||||
t.Fatalf("unexpected added name for call_glob: %q", addedNames["call_glob"])
|
||||
}
|
||||
|
||||
if !gjson.Valid(doneArgs["call_read"]) {
|
||||
t.Fatalf("invalid JSON args for call_read: %q", doneArgs["call_read"])
|
||||
}
|
||||
if !gjson.Valid(doneArgs["call_glob"]) {
|
||||
t.Fatalf("invalid JSON args for call_glob: %q", doneArgs["call_glob"])
|
||||
}
|
||||
if strings.Contains(doneArgs["call_read"], "}{") {
|
||||
t.Fatalf("call_read args were concatenated: %q", doneArgs["call_read"])
|
||||
}
|
||||
if strings.Contains(doneArgs["call_glob"], "}{") {
|
||||
t.Fatalf("call_glob args were concatenated: %q", doneArgs["call_glob"])
|
||||
}
|
||||
|
||||
if doneNames["call_read"] != "read" {
|
||||
t.Fatalf("unexpected done name for call_read: %q", doneNames["call_read"])
|
||||
}
|
||||
if doneNames["call_glob"] != "glob" {
|
||||
t.Fatalf("unexpected done name for call_glob: %q", doneNames["call_glob"])
|
||||
}
|
||||
|
||||
if got := gjson.Get(doneArgs["call_read"], "filePath").String(); got != `C:\repo` {
|
||||
t.Fatalf("unexpected filePath for call_read: %q", got)
|
||||
}
|
||||
if got := gjson.Get(doneArgs["call_glob"], "path").String(); got != `C:\repo` {
|
||||
t.Fatalf("unexpected path for call_glob: %q", got)
|
||||
}
|
||||
if got := gjson.Get(doneArgs["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
|
||||
t.Fatalf("unexpected pattern for call_glob: %q", got)
|
||||
}
|
||||
|
||||
if len(outputItems) != 2 {
|
||||
t.Fatalf("expected 2 function_call items in response.output, got %d", len(outputItems))
|
||||
}
|
||||
if outputItems["call_read"].Get("name").String() != "read" {
|
||||
t.Fatalf("unexpected response.output name for call_read: %q", outputItems["call_read"].Get("name").String())
|
||||
}
|
||||
if outputItems["call_glob"].Get("name").String() != "glob" {
|
||||
t.Fatalf("unexpected response.output name for call_glob: %q", outputItems["call_glob"].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
type fcEvent struct {
|
||||
outputIndex int64
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
added := map[string]fcEvent{}
|
||||
done := map[string]fcEvent{}
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := data.Get("item.call_id").String()
|
||||
added[callID] = fcEvent{
|
||||
outputIndex: data.Get("output_index").Int(),
|
||||
name: data.Get("item.name").String(),
|
||||
}
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := data.Get("item.call_id").String()
|
||||
done[callID] = fcEvent{
|
||||
outputIndex: data.Get("output_index").Int(),
|
||||
name: data.Get("item.name").String(),
|
||||
arguments: data.Get("item.arguments").String(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(added) != 2 {
|
||||
t.Fatalf("expected 2 function_call added events, got %d", len(added))
|
||||
}
|
||||
if len(done) != 2 {
|
||||
t.Fatalf("expected 2 function_call done events, got %d", len(done))
|
||||
}
|
||||
|
||||
if added["call_choice0"].name != "glob" {
|
||||
t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
|
||||
}
|
||||
if added["call_choice1"].name != "read" {
|
||||
t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
|
||||
}
|
||||
if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
|
||||
t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
|
||||
}
|
||||
|
||||
if !gjson.Valid(done["call_choice0"].arguments) {
|
||||
t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
|
||||
}
|
||||
if !gjson.Valid(done["call_choice1"].arguments) {
|
||||
t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
|
||||
}
|
||||
if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
|
||||
t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
|
||||
}
|
||||
if done["call_choice0"].name != "glob" {
|
||||
t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
|
||||
}
|
||||
if done["call_choice1"].name != "read" {
|
||||
t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var messageOutputIndex int64 = -1
|
||||
var toolOutputIndex int64 = -1
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
if ev != "response.output_item.added" {
|
||||
continue
|
||||
}
|
||||
switch data.Get("item.type").String() {
|
||||
case "message":
|
||||
if data.Get("item.id").String() == "msg_resp_mixed_0" {
|
||||
messageOutputIndex = data.Get("output_index").Int()
|
||||
}
|
||||
case "function_call":
|
||||
if data.Get("item.call_id").String() == "call_choice1" {
|
||||
toolOutputIndex = data.Get("output_index").Int()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messageOutputIndex < 0 {
|
||||
t.Fatal("did not find message output index")
|
||||
}
|
||||
if toolOutputIndex < 0 {
|
||||
t.Fatal("did not find tool output index")
|
||||
}
|
||||
if messageOutputIndex == toolOutputIndex {
|
||||
t.Fatalf("expected distinct output indexes for message and tool call, both got %d", messageOutputIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||
}
|
||||
|
||||
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||
|
||||
var param any
|
||||
var out [][]byte
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var doneIndexes []int64
|
||||
var completedOrder []string
|
||||
|
||||
for _, chunk := range out {
|
||||
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() == "function_call" {
|
||||
doneIndexes = append(doneIndexes, data.Get("output_index").Int())
|
||||
}
|
||||
case "response.completed":
|
||||
for _, item := range data.Get("response.output").Array() {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
completedOrder = append(completedOrder, item.Get("call_id").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(doneIndexes) != 2 {
|
||||
t.Fatalf("expected 2 function_call done indexes, got %d", len(doneIndexes))
|
||||
}
|
||||
if doneIndexes[0] >= doneIndexes[1] {
|
||||
t.Fatalf("expected ascending done output indexes, got %v", doneIndexes)
|
||||
}
|
||||
if len(completedOrder) != 2 {
|
||||
t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedOrder))
|
||||
}
|
||||
if completedOrder[0] != "call_glob" || completedOrder[1] != "call_read" {
|
||||
t.Fatalf("unexpected completed function_call order: %v", completedOrder)
|
||||
}
|
||||
}
|
||||
@@ -201,6 +201,7 @@ var zhStrings = map[string]string{
|
||||
"usage_output": "输出",
|
||||
"usage_cached": "缓存",
|
||||
"usage_reasoning": "思考",
|
||||
"usage_time": "时间",
|
||||
|
||||
// ── Logs ──
|
||||
"logs_title": "📋 日志",
|
||||
@@ -352,6 +353,7 @@ var enStrings = map[string]string{
|
||||
"usage_output": "Output",
|
||||
"usage_cached": "Cached",
|
||||
"usage_reasoning": "Reasoning",
|
||||
"usage_time": "Time",
|
||||
|
||||
// ── Logs ──
|
||||
"logs_title": "📋 Logs",
|
||||
|
||||
@@ -248,6 +248,9 @@ func (m usageTabModel) renderContent() string {
|
||||
|
||||
// Token type breakdown from details
|
||||
sb.WriteString(m.renderTokenBreakdown(stats))
|
||||
|
||||
// Latency breakdown from details
|
||||
sb.WriteString(m.renderLatencyBreakdown(stats))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -308,6 +311,57 @@ func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
|
||||
}
|
||||
|
||||
// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
|
||||
func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
|
||||
details, ok := modelStats["details"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
detailList, ok := details.([]any)
|
||||
if !ok || len(detailList) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var totalLatency int64
|
||||
var count int
|
||||
var minLatency, maxLatency int64
|
||||
first := true
|
||||
|
||||
for _, d := range detailList {
|
||||
dm, ok := d.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
latencyMs := int64(getFloat(dm, "latency_ms"))
|
||||
if latencyMs <= 0 {
|
||||
continue
|
||||
}
|
||||
totalLatency += latencyMs
|
||||
count++
|
||||
if first {
|
||||
minLatency = latencyMs
|
||||
maxLatency = latencyMs
|
||||
first = false
|
||||
} else {
|
||||
if latencyMs < minLatency {
|
||||
minLatency = latencyMs
|
||||
}
|
||||
if latencyMs > maxLatency {
|
||||
maxLatency = latencyMs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
avgLatency := totalLatency / int64(count)
|
||||
return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
|
||||
avgLatency, minLatency, maxLatency)
|
||||
}
|
||||
|
||||
// renderBarChart renders a simple ASCII horizontal bar chart.
|
||||
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
|
||||
if maxBarWidth < 10 {
|
||||
|
||||
134
internal/tui/usage_tab_test.go
Normal file
134
internal/tui/usage_tab_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRenderLatencyBreakdown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelStats map[string]any
|
||||
wantEmpty bool
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "no details",
|
||||
modelStats: map[string]any{},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty details",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{},
|
||||
},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "details with zero latency",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "single request with latency",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(1500),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: false,
|
||||
wantContains: "avg 1500ms min 1500ms max 1500ms",
|
||||
},
|
||||
{
|
||||
name: "multiple requests with varying latency",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(100),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(200),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(300),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: false,
|
||||
wantContains: "avg 200ms min 100ms max 300ms",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid latency values",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(0),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(1500),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: false,
|
||||
wantContains: "avg 1000ms min 500ms max 1500ms",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := usageTabModel{}
|
||||
result := m.renderLatencyBreakdown(tt.modelStats)
|
||||
|
||||
if tt.wantEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == "" {
|
||||
t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
|
||||
t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTimeTranslations(t *testing.T) {
|
||||
prevLocale := CurrentLocale()
|
||||
t.Cleanup(func() {
|
||||
SetLocale(prevLocale)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
locale string
|
||||
want string
|
||||
}{
|
||||
{locale: "en", want: "Time"},
|
||||
{locale: "zh", want: "时间"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.locale, func(t *testing.T) {
|
||||
SetLocale(tt.locale)
|
||||
if got := T("usage_time"); got != tt.want {
|
||||
t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -57,6 +57,12 @@ func GetProviderName(modelName string) []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// Fallback: if cursor provider has registered models, route unknown models to it.
|
||||
// Cursor acts as a universal proxy supporting multiple model families (Claude, GPT, Gemini, etc.).
|
||||
if models := registry.GetGlobalRegistry().GetAvailableModelsByProvider("cursor"); len(models) > 0 {
|
||||
return []string{"cursor"}
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
|
||||
}
|
||||
if oldCfg.QuotaExceeded.AntigravityCredits != newCfg.QuotaExceeded.AntigravityCredits {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.antigravity-credits: %t -> %t", oldCfg.QuotaExceeded.AntigravityCredits, newCfg.QuotaExceeded.AntigravityCredits))
|
||||
}
|
||||
|
||||
if oldCfg.Routing.Strategy != newCfg.Routing.Strategy {
|
||||
changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy))
|
||||
|
||||
@@ -229,7 +229,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
MaxRetryCredentials: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}},
|
||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||
@@ -253,7 +253,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
MaxRetryCredentials: 3,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true, AntigravityCredits: true},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
{APIKey: "c2"},
|
||||
@@ -297,6 +297,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
expectContains(t, details, "nonstream-keepalive-interval: 0 -> 5")
|
||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.antigravity-credits: false -> true")
|
||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||
expectContains(t, details, "claude-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "codex-api-key count: 1 -> 2")
|
||||
@@ -320,7 +321,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
MaxRetryCredentials: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false, AntigravityCredits: false},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
|
||||
},
|
||||
@@ -374,7 +375,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
MaxRetryCredentials: 3,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true, AntigravityCredits: true},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
@@ -437,6 +438,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.antigravity-credits: false -> true")
|
||||
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
|
||||
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")
|
||||
|
||||
@@ -157,6 +157,7 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
|
||||
}
|
||||
}
|
||||
}
|
||||
coreauth.ApplyCustomHeadersFromMetadata(a)
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
// For codex auth files, extract plan_type from the JWT id_token.
|
||||
if provider == "codex" {
|
||||
@@ -233,6 +234,11 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" {
|
||||
attrs["note"] = noteVal
|
||||
}
|
||||
for k, v := range primary.Attributes {
|
||||
if strings.HasPrefix(k, "header:") && strings.TrimSpace(v) != "" {
|
||||
attrs[k] = v
|
||||
}
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
|
||||
@@ -69,10 +69,14 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
|
||||
// Create a valid auth file
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
"headers": map[string]string{
|
||||
" X-Test ": " value ",
|
||||
"X-Empty": " ",
|
||||
},
|
||||
"disable_cooling": true,
|
||||
"request_retry": 2,
|
||||
}
|
||||
@@ -110,6 +114,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if got := auths[0].Attributes["header:X-Test"]; got != "value" {
|
||||
t.Errorf("expected header:X-Test value, got %q", got)
|
||||
}
|
||||
if _, ok := auths[0].Attributes["header:X-Empty"]; ok {
|
||||
t.Errorf("expected header:X-Empty to be absent, got %q", auths[0].Attributes["header:X-Empty"])
|
||||
}
|
||||
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
|
||||
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
|
||||
}
|
||||
@@ -450,8 +460,9 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
Prefix: "test-prefix",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
"header:X-Tra": "value",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
@@ -506,6 +517,9 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
if v.Attributes["runtime_only"] != "true" {
|
||||
t.Error("expected runtime_only=true")
|
||||
}
|
||||
if got := v.Attributes["header:X-Tra"]; got != "value" {
|
||||
t.Errorf("expected virtual %d header:X-Tra %q, got %q", i, "value", got)
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != "primary-id" {
|
||||
t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
|
||||
@@ -136,6 +136,8 @@ type authAwareStreamExecutor struct {
|
||||
|
||||
type invalidJSONStreamExecutor struct{}
|
||||
|
||||
type splitResponsesEventStreamExecutor struct{}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -165,6 +167,36 @@ func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *corea
|
||||
}
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) Identifier() string { return "split-sse" }
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
ch := make(chan coreexecutor.StreamChunk, 2)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed")}
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")}
|
||||
close(ch)
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *splitResponsesEventStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -607,3 +639,52 @@ func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_AllowsSplitOpenAIResponsesSSEEventLines(t *testing.T) {
|
||||
executor := &splitResponsesEventStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "split-sse",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []string
|
||||
for chunk := range dataChan {
|
||||
got = append(got, string(chunk))
|
||||
}
|
||||
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 forwarded chunks, got %d: %#v", len(got), got)
|
||||
}
|
||||
if got[0] != "event: response.completed" {
|
||||
t.Fatalf("unexpected first chunk: %q", got[0])
|
||||
}
|
||||
expectedData := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
|
||||
if got[1] != expectedData {
|
||||
t.Fatalf("unexpected second chunk.\nGot: %q\nWant: %q", got[1], expectedData)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// gatewayHeaderPrefixes lists header name prefixes injected by known AI gateway
|
||||
// proxies. Claude Code's client-side telemetry detects these and reports the
|
||||
// gateway type, so we strip them from upstream responses to avoid detection.
|
||||
var gatewayHeaderPrefixes = []string{
|
||||
"x-litellm-",
|
||||
"helicone-",
|
||||
"x-portkey-",
|
||||
"cf-aig-",
|
||||
"x-kong-",
|
||||
"x-bt-",
|
||||
}
|
||||
|
||||
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
|
||||
// be forwarded by proxies, plus security-sensitive headers that should not leak.
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
@@ -40,6 +52,19 @@ func FilterUpstreamHeaders(src http.Header) http.Header {
|
||||
if _, scoped := connectionScoped[canonicalKey]; scoped {
|
||||
continue
|
||||
}
|
||||
// Strip headers injected by known AI gateway proxies to avoid
|
||||
// Claude Code client-side gateway detection.
|
||||
lowerKey := strings.ToLower(key)
|
||||
gatewayMatch := false
|
||||
for _, prefix := range gatewayHeaderPrefixes {
|
||||
if strings.HasPrefix(lowerKey, prefix) {
|
||||
gatewayMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if gatewayMatch {
|
||||
continue
|
||||
}
|
||||
dst[key] = values
|
||||
}
|
||||
if len(dst) == 0 {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package openai
|
||||
|
||||
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIChatEndpoint = "/chat/completions"
|
||||
@@ -12,6 +15,12 @@ func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool)
|
||||
return "", false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(modelName, "")
|
||||
if info == nil {
|
||||
baseModel := thinking.ParseSuffix(modelName).ModelName
|
||||
if baseModel != "" && baseModel != modelName {
|
||||
info = registry.GetGlobalRegistry().GetModelInfo(baseModel, "")
|
||||
}
|
||||
}
|
||||
if info == nil || len(info.SupportedEndpoints) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
29
sdk/api/handlers/openai/endpoint_compat_test.go
Normal file
29
sdk/api/handlers/openai/endpoint_compat_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestResolveEndpointOverride_StripsThinkingSuffix(t *testing.T) {
|
||||
const clientID = "test-endpoint-compat-suffix"
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
|
||||
{
|
||||
ID: "test-gemini-chat-only",
|
||||
SupportedEndpoints: []string{openAIChatEndpoint},
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(clientID)
|
||||
})
|
||||
|
||||
override, ok := resolveEndpointOverride("test-gemini-chat-only(high)", openAIResponsesEndpoint)
|
||||
if !ok {
|
||||
t.Fatalf("expected endpoint override to be resolved")
|
||||
}
|
||||
if override != openAIChatEndpoint {
|
||||
t.Fatalf("override endpoint = %q, want %q", override, openAIChatEndpoint)
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,9 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -22,6 +24,177 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
|
||||
if w == nil || len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if bytes.HasSuffix(chunk, []byte("\n\n")) || bytes.HasSuffix(chunk, []byte("\r\n\r\n")) {
|
||||
return
|
||||
}
|
||||
suffix := []byte("\n\n")
|
||||
if bytes.HasSuffix(chunk, []byte("\r\n")) {
|
||||
suffix = []byte("\r\n")
|
||||
} else if bytes.HasSuffix(chunk, []byte("\n")) {
|
||||
suffix = []byte("\n")
|
||||
}
|
||||
if _, err := w.Write(suffix); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type responsesSSEFramer struct {
|
||||
pending []byte
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
if responsesSSENeedsLineBreak(f.pending, chunk) {
|
||||
f.pending = append(f.pending, '\n')
|
||||
}
|
||||
f.pending = append(f.pending, chunk...)
|
||||
for {
|
||||
frameLen := responsesSSEFrameLen(f.pending)
|
||||
if frameLen == 0 {
|
||||
break
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending[:frameLen])
|
||||
copy(f.pending, f.pending[frameLen:])
|
||||
f.pending = f.pending[:len(f.pending)-frameLen]
|
||||
}
|
||||
if len(bytes.TrimSpace(f.pending)) == 0 {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
|
||||
return
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending)
|
||||
f.pending = f.pending[:0]
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) Flush(w io.Writer) {
|
||||
if len(f.pending) == 0 {
|
||||
return
|
||||
}
|
||||
if len(bytes.TrimSpace(f.pending)) == 0 {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
if !responsesSSECanEmitWithoutDelimiter(f.pending) {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending)
|
||||
f.pending = f.pending[:0]
|
||||
}
|
||||
|
||||
func responsesSSEFrameLen(chunk []byte) int {
|
||||
if len(chunk) == 0 {
|
||||
return 0
|
||||
}
|
||||
lf := bytes.Index(chunk, []byte("\n\n"))
|
||||
crlf := bytes.Index(chunk, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0:
|
||||
if crlf < 0 {
|
||||
return 0
|
||||
}
|
||||
return crlf + 4
|
||||
case crlf < 0:
|
||||
return lf + 2
|
||||
case lf < crlf:
|
||||
return lf + 2
|
||||
default:
|
||||
return crlf + 4
|
||||
}
|
||||
}
|
||||
|
||||
func responsesSSENeedsMoreData(chunk []byte) bool {
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if len(trimmed) == 0 {
|
||||
return false
|
||||
}
|
||||
return responsesSSEHasField(trimmed, []byte("event:")) && !responsesSSEHasField(trimmed, []byte("data:"))
|
||||
}
|
||||
|
||||
func responsesSSEHasField(chunk []byte, prefix []byte) bool {
|
||||
s := chunk
|
||||
for len(s) > 0 {
|
||||
line := s
|
||||
if i := bytes.IndexByte(s, '\n'); i >= 0 {
|
||||
line = s[:i]
|
||||
s = s[i+1:]
|
||||
} else {
|
||||
s = nil
|
||||
}
|
||||
line = bytes.TrimSpace(line)
|
||||
if bytes.HasPrefix(line, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func responsesSSECanEmitWithoutDelimiter(chunk []byte) bool {
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if len(trimmed) == 0 || responsesSSENeedsMoreData(trimmed) || !responsesSSEHasField(trimmed, []byte("data:")) {
|
||||
return false
|
||||
}
|
||||
return responsesSSEDataLinesValid(trimmed)
|
||||
}
|
||||
|
||||
func responsesSSEDataLinesValid(chunk []byte) bool {
|
||||
s := chunk
|
||||
for len(s) > 0 {
|
||||
line := s
|
||||
if i := bytes.IndexByte(s, '\n'); i >= 0 {
|
||||
line = s[:i]
|
||||
s = s[i+1:]
|
||||
} else {
|
||||
s = nil
|
||||
}
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[len("data:"):])
|
||||
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if !json.Valid(data) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func responsesSSENeedsLineBreak(pending, chunk []byte) bool {
|
||||
if len(pending) == 0 || len(chunk) == 0 {
|
||||
return false
|
||||
}
|
||||
if bytes.HasSuffix(pending, []byte("\n")) || bytes.HasSuffix(pending, []byte("\r")) {
|
||||
return false
|
||||
}
|
||||
if chunk[0] == '\n' || chunk[0] == '\r' {
|
||||
return false
|
||||
}
|
||||
trimmed := bytes.TrimLeft(chunk, " \t")
|
||||
if len(trimmed) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range [][]byte{[]byte("data:"), []byte("event:"), []byte("id:"), []byte("retry:"), []byte(":")} {
|
||||
if bytes.HasPrefix(trimmed, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type OpenAIResponsesAPIHandler struct {
|
||||
@@ -234,6 +407,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
framer := &responsesSSEFramer{}
|
||||
|
||||
// Peek at the first chunk
|
||||
for {
|
||||
@@ -271,15 +445,11 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
framer.WriteChunk(c.Writer, chunk)
|
||||
flusher.Flush()
|
||||
|
||||
// Continue
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, framer)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -397,16 +567,16 @@ func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, framer *responsesSSEFramer) {
|
||||
if framer == nil {
|
||||
framer = &responsesSSEFramer{}
|
||||
}
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
framer.WriteChunk(c.Writer, chunk)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
framer.Flush(c.Writer)
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
@@ -422,6 +592,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||
},
|
||||
WriteDone: func() {
|
||||
framer.Flush(c.Writer)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T
|
||||
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
body := recorder.Body.String()
|
||||
if !strings.Contains(body, `"type":"error"`) {
|
||||
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||
|
||||
142
sdk/api/handlers/openai/openai_responses_handlers_stream_test.go
Normal file
142
sdk/api/handlers/openai/openai_responses_handlers_stream_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||
}
|
||||
|
||||
return h, recorder, c, flusher
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 2)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}")
|
||||
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
body := recorder.Body.String()
|
||||
parts := strings.Split(strings.TrimSpace(body), "\n\n")
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), body)
|
||||
}
|
||||
|
||||
expectedPart1 := "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}"
|
||||
if parts[0] != expectedPart1 {
|
||||
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
|
||||
}
|
||||
|
||||
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
|
||||
if parts[1] != expectedPart2 {
|
||||
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 3)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("event: response.created")
|
||||
data <- []byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}")
|
||||
data <- []byte("\n")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
got := strings.TrimSuffix(recorder.Body.String(), "\n")
|
||||
want := "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected split-event framing.\nGot: %q\nWant: %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamPreservesValidFullSSEEventChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
chunk := []byte("event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n")
|
||||
data <- chunk
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
got := strings.TrimSuffix(recorder.Body.String(), "\n")
|
||||
if got != string(chunk) {
|
||||
t.Fatalf("unexpected full-event framing.\nGot: %q\nWant: %q", got, string(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamBuffersSplitDataPayloadChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 2)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.created\"")
|
||||
data <- []byte(",\"response\":{\"id\":\"resp-1\"}}")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
got := recorder.Body.String()
|
||||
want := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n\n"
|
||||
if got != want {
|
||||
t.Fatalf("unexpected split-data framing.\nGot: %q\nWant: %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesSSENeedsLineBreakSkipsChunksThatAlreadyStartWithNewline(t *testing.T) {
|
||||
if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\n")) {
|
||||
t.Fatal("expected no injected newline before newline-only chunk")
|
||||
}
|
||||
if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\r\n")) {
|
||||
t.Fatal("expected no injected newline before CRLF chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamDropsIncompleteTrailingDataChunkOnFlush(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.created\"")
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
if got := recorder.Body.String(); got != "\n" {
|
||||
t.Fatalf("expected incomplete trailing data to be dropped on flush.\nGot: %q", got)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user