mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 09:18:12 +00:00
Compare commits
153 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8d1bc06ea | ||
|
|
d5930f4e44 | ||
|
|
9b7d7021af | ||
|
|
e41c22ef44 | ||
|
|
55271403fb | ||
|
|
36fba66619 | ||
|
|
b9b127a7ea | ||
|
|
2741e7b7b3 | ||
|
|
1767a56d4f | ||
|
|
779e6c2d2f | ||
|
|
73c831747b | ||
|
|
b8b89f34f4 | ||
|
|
1fa094dac6 | ||
|
|
f55754621f | ||
|
|
ac26e7db43 | ||
|
|
10b824fcac | ||
|
|
e5d3541b5a | ||
|
|
79755e76ea | ||
|
|
35f158d526 | ||
|
|
6962e09dd9 | ||
|
|
4c4cbd44da | ||
|
|
26eca8b6ba | ||
|
|
62b17f40a1 | ||
|
|
511b8a992e | ||
|
|
7dccc7ba2f | ||
|
|
70c90687fd | ||
|
|
8144ffd5c8 | ||
|
|
0ab977c236 | ||
|
|
224f0de353 | ||
|
|
6b45d311ec | ||
|
|
d54de441d3 | ||
|
|
7386a70724 | ||
|
|
1821bf7051 | ||
|
|
d42b5d4e78 | ||
|
|
1b7447b682 | ||
|
|
40dee4453a | ||
|
|
8902e1cccb | ||
|
|
de5fe71478 | ||
|
|
dcfbec2990 | ||
|
|
c95620f90e | ||
|
|
754f3bcbc3 | ||
|
|
36973d4a6f | ||
|
|
9613f0b3f9 | ||
|
|
274f29e26b | ||
|
|
c8e79c3787 | ||
|
|
8afef43887 | ||
|
|
c1083cbfc6 | ||
|
|
c89d19b300 | ||
|
|
1e6bc81cfd | ||
|
|
1a149475e0 | ||
|
|
e5166841db | ||
|
|
19c52bcb60 | ||
|
|
bb9b2d1758 | ||
|
|
7fa527193c | ||
|
|
ed0eb51b4d | ||
|
|
0e4f669c8b | ||
|
|
76c064c729 | ||
|
|
d2f652f436 | ||
|
|
6a452a54d5 | ||
|
|
9e5693e74f | ||
|
|
528b1a2307 | ||
|
|
0cc978ec1d | ||
|
|
d312422ab4 | ||
|
|
fee736933b | ||
|
|
09c92aa0b5 | ||
|
|
8c67b3ae64 | ||
|
|
000e4ceb4e | ||
|
|
5c99846ecf | ||
|
|
cc32f5ff61 | ||
|
|
fbff68b9e0 | ||
|
|
7e1a543b79 | ||
|
|
d475aaba96 | ||
|
|
1dc4ecb1b8 | ||
|
|
1315f710f5 | ||
|
|
96f55570f7 | ||
|
|
0906aeca87 | ||
|
|
7333619f15 | ||
|
|
97c0487add | ||
|
|
74b862d8b8 | ||
|
|
2db8df8e38 | ||
|
|
a576088d5f | ||
|
|
66ff916838 | ||
|
|
7b0453074e | ||
|
|
a000eb523d | ||
|
|
18a4fedc7f | ||
|
|
5d6cdccda0 | ||
|
|
1b7f4ac3e1 | ||
|
|
afc1a5b814 | ||
|
|
7ed38db54f | ||
|
|
28c10f4e69 | ||
|
|
6e12441a3b | ||
|
|
65c439c18d | ||
|
|
0ed2d16596 | ||
|
|
db335ac616 | ||
|
|
f3c59165d7 | ||
|
|
e6690cb447 | ||
|
|
35907416b8 | ||
|
|
e8bb350467 | ||
|
|
5331d51f27 | ||
|
|
755ca75879 | ||
|
|
2398ebad55 | ||
|
|
c1bf298216 | ||
|
|
e005208d76 | ||
|
|
d1df70d02f | ||
|
|
f81acd0760 | ||
|
|
636da4c932 | ||
|
|
cccb77b552 | ||
|
|
2bd646ad70 | ||
|
|
52c1fa025e | ||
|
|
680105f84d | ||
|
|
f7069e9548 | ||
|
|
7275e99b41 | ||
|
|
c28b65f849 | ||
|
|
793840cdb4 | ||
|
|
8f421de532 | ||
|
|
be2dd60ee7 | ||
|
|
ea3e0b713e | ||
|
|
8179d5a8a4 | ||
|
|
6fa7abe434 | ||
|
|
5135c22cd6 | ||
|
|
1e27990561 | ||
|
|
e1e9fc43c1 | ||
|
|
b2921518ac | ||
|
|
dd64adbeeb | ||
|
|
616d41c06a | ||
|
|
e0e337aeb9 | ||
|
|
d52839fced | ||
|
|
4022e69651 | ||
|
|
56073ded69 | ||
|
|
9738a53f49 | ||
|
|
be3f8dbf7e | ||
|
|
9c6c3612a8 | ||
|
|
19e1a4447a | ||
|
|
7c2ad4cda2 | ||
|
|
fb95813fbf | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
152c310bb7 | ||
|
|
f6bbca35ab | ||
|
|
c8cee6a209 | ||
|
|
5c817a9b42 | ||
|
|
5da0decef6 | ||
|
|
5b6342e6ac | ||
|
|
c3762328a5 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
# Binaries
|
||||
cli-proxy-api
|
||||
cliproxy
|
||||
/server
|
||||
*.exe
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ builds:
|
||||
- linux
|
||||
- windows
|
||||
- darwin
|
||||
- freebsd
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# CLIProxyAPI Plus
|
||||
|
||||
[English](README.md) | 中文
|
||||
[English](README.md) | 中文 | [日本語](README_JA.md)
|
||||
|
||||
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
||||
|
||||
|
||||
199
README_JA.md
Normal file
199
README_JA.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# CLI Proxy API
|
||||
|
||||
[English](README.md) | [中文](README_CN.md) | 日本語
|
||||
|
||||
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
|
||||
|
||||
OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポートしています。
|
||||
|
||||
ローカルまたはマルチアカウントのCLIアクセスを、OpenAI(Responses含む)/Gemini/Claude互換のクライアントやSDKで利用できます。
|
||||
|
||||
## スポンサー
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
|
||||
|
||||
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
|
||||
|
||||
GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
|
||||
---
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
|
||||
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
|
||||
<td>LingtrueAPIのスポンサーシップに感謝します!LingtrueAPIはグローバルな大規模モデルAPIリレーサービスプラットフォームで、Claude Code、Codex、GeminiなどのトップモデルAPI呼び出しサービスを提供し、ユーザーが低コストかつ高い安定性で世界中のAI能力に接続できるよう支援しています。LingtrueAPIは本ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.lingtrue.com/register">こちらのリンク</a>から登録し、初回チャージ時にプロモーションコード「LingtrueAPI」を入力すると10%割引になります。</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## 概要
|
||||
|
||||
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
|
||||
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
|
||||
- OAuthログインによるClaude Codeサポート
|
||||
- OAuthログインによるQwen Codeサポート
|
||||
- OAuthログインによるiFlowサポート
|
||||
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
|
||||
- ストリーミングおよび非ストリーミングレスポンス
|
||||
- 関数呼び出し/ツールのサポート
|
||||
- マルチモーダル入力サポート(テキストと画像)
|
||||
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||
- Generative Language APIキーのサポート
|
||||
- AI Studioビルドのマルチアカウント負荷分散
|
||||
- Gemini CLIのマルチアカウント負荷分散
|
||||
- Claude Codeのマルチアカウント負荷分散
|
||||
- Qwen Codeのマルチアカウント負荷分散
|
||||
- iFlowのマルチアカウント負荷分散
|
||||
- OpenAI Codexのマルチアカウント負荷分散
|
||||
- 設定によるOpenAI互換アップストリームプロバイダー(例:OpenRouter)
|
||||
- プロキシ埋め込み用の再利用可能なGo SDK(`docs/sdk-usage.md`を参照)
|
||||
|
||||
## はじめに
|
||||
|
||||
CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/)
|
||||
|
||||
## 管理API
|
||||
|
||||
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
|
||||
|
||||
## Amp CLIサポート
|
||||
|
||||
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
|
||||
|
||||
- Ampの APIパターン用のプロバイダールートエイリアス(`/api/provider/{provider}/v1...`)
|
||||
- OAuth認証およびアカウント機能用の管理プロキシ
|
||||
- 自動ルーティングによるスマートモデルフォールバック
|
||||
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
|
||||
|
||||
特定のバックエンド系統のリクエスト/レスポンス形状が必要な場合は、統合された `/v1/...` エンドポイントよりも provider-specific のパスを優先してください。
|
||||
|
||||
- messages 系のバックエンドには `/api/provider/{provider}/v1/messages`
|
||||
- モデル単位の generate 系エンドポイントには `/api/provider/{provider}/v1beta/models/...`
|
||||
- chat-completions 系のバックエンドには `/api/provider/{provider}/v1/chat/completions`
|
||||
|
||||
これらのパスはプロトコル面の選択には役立ちますが、同じクライアント向けモデル名が複数バックエンドで再利用されている場合、それだけで推論実行系が一意に固定されるわけではありません。実際の推論ルーティングは、引き続きリクエスト内の model/alias 解決に従います。厳密にバックエンドを固定したい場合は、一意な alias や prefix を使うか、クライアント向けモデル名の重複自体を避けてください。
|
||||
|
||||
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||
|
||||
## SDKドキュメント
|
||||
|
||||
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
|
||||
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
|
||||
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
|
||||
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
|
||||
- カスタムプロバイダーの例:`examples/custom-provider`
|
||||
|
||||
## コントリビューション
|
||||
|
||||
コントリビューションを歓迎します!お気軽にPull Requestを送ってください。
|
||||
|
||||
1. リポジトリをフォーク
|
||||
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`)
|
||||
3. 変更をコミット(`git commit -m 'Add some amazing feature'`)
|
||||
4. ブランチにプッシュ(`git push origin feature/amazing-feature`)
|
||||
5. Pull Requestを作成
|
||||
|
||||
## 関連プロジェクト
|
||||
|
||||
CLIProxyAPIをベースにした以下のプロジェクトがあります:
|
||||
|
||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
||||
|
||||
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
|
||||
|
||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
||||
|
||||
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
|
||||
|
||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
||||
|
||||
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデル(Gemini、Codex、Antigravity)を即座に切り替えるCLIラッパー - APIキー不要
|
||||
|
||||
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
||||
|
||||
CLIProxyAPI管理用のmacOSネイティブGUI:OAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
|
||||
|
||||
### [CodMate](https://github.com/loocor/CodMate)
|
||||
|
||||
CLI AIセッション(Codex、Claude Code、Gemini CLI)を管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
|
||||
|
||||
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||
|
||||
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
|
||||
|
||||
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||
|
||||
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
|
||||
|
||||
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
||||
|
||||
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
|
||||
|
||||
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||
|
||||
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替え(Main / Plus)、自動ダウンロードおよび自動更新に対応
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
|
||||
|
||||
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
||||
|
||||
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
|
||||
|
||||
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||
|
||||
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
|
||||
|
||||
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
|
||||
|
||||
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LAN(ローカルエリアネットワーク)を介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
## その他の選択肢
|
||||
|
||||
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです:
|
||||
|
||||
### [9Router](https://github.com/decolua/9router)
|
||||
|
||||
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換(OpenAI/Claude/Gemini/Ollama)、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツール(Cursor、Claude Code、Cline、RooCode)のサポートをゼロから構築 - APIキー不要
|
||||
|
||||
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
||||
|
||||
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
|
||||
|
||||
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
## ライセンス
|
||||
|
||||
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。
|
||||
BIN
assets/bmoplus.png
Normal file
BIN
assets/bmoplus.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
BIN
assets/lingtrue.png
Normal file
BIN
assets/lingtrue.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 129 KiB |
20
cmd/mcpdebug/main.go
Normal file
20
cmd/mcpdebug/main.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Encode MCP result with empty execId
|
||||
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
|
||||
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
|
||||
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
|
||||
|
||||
// Write to file for analysis
|
||||
os.WriteFile("mcp_result.bin", resultBytes)
|
||||
fmt.Println("Wrote mcp_result.bin")
|
||||
}
|
||||
32
cmd/protocheck/main.go
Normal file
32
cmd/protocheck/main.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||
|
||||
// Try different field names
|
||||
names := []string{
|
||||
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||
"shell_result", "shellResult",
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
fd := ecm.Descriptor().Fields().ByName(name)
|
||||
if fd != nil {
|
||||
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
|
||||
} else {
|
||||
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
// List all fields
|
||||
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||
f := ecm.Descriptor().Fields().Get(i)
|
||||
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,7 @@ func main() {
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var cursorLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
@@ -95,6 +96,7 @@ func main() {
|
||||
var kiroIDCRegion string
|
||||
var kiroIDCFlow string
|
||||
var githubCopilotLogin bool
|
||||
var codeBuddyLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
@@ -103,6 +105,7 @@ func main() {
|
||||
var standalone bool
|
||||
var noIncognito bool
|
||||
var useIncognito bool
|
||||
var localModel bool
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
@@ -121,6 +124,7 @@ func main() {
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
@@ -131,12 +135,14 @@ func main() {
|
||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -514,6 +520,9 @@ func main() {
|
||||
} else if githubCopilotLogin {
|
||||
// Handle GitHub Copilot login
|
||||
cmd.DoGitHubCopilotLogin(cfg, options)
|
||||
} else if codeBuddyLogin {
|
||||
// Handle CodeBuddy login
|
||||
cmd.DoCodeBuddyLogin(cfg, options)
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
@@ -537,6 +546,8 @@ func main() {
|
||||
cmd.DoGitLabTokenLogin(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if cursorLogin {
|
||||
cmd.DoCursorLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
@@ -578,11 +589,16 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
if localModel && (!tuiMode || standalone) {
|
||||
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
|
||||
}
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
if !localModel {
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
}
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
@@ -655,7 +671,9 @@ func main() {
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
if !localModel {
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
}
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
|
||||
@@ -25,6 +25,10 @@ remote-management:
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
|
||||
# When enabled, the panel is only downloaded on first access if missing, and never auto-updated afterward.
|
||||
# disable-auto-update-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||
|
||||
@@ -175,12 +179,19 @@ nonstream-keepalive-interval: 0
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||
# when the client omits them, while OS/arch remain runtime-derived. When
|
||||
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
|
||||
# while user-agent/package-version/runtime-version seed a software fingerprint that can
|
||||
# still upgrade to newer official Claude client versions.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# os: "MacOS"
|
||||
# arch: "arm64"
|
||||
# timeout: "600"
|
||||
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
|
||||
|
||||
# Default headers for Codex OAuth model requests.
|
||||
# These are used only for file-backed/OAuth Codex requests when the client
|
||||
@@ -231,7 +242,9 @@ nonstream-keepalive-interval: 0
|
||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# thinking: # optional: omit to default to levels ["low","medium","high"]
|
||||
# levels: ["low", "medium", "high"]
|
||||
# # You may repeat the same alias to build an internal model pool.
|
||||
# # The client still sees only one alias in the model list.
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
@@ -300,6 +313,10 @@ nonstream-keepalive-interval: 0
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||
# you select the protocol surface, but inference backend selection can still follow the resolved
|
||||
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
|
||||
@@ -52,11 +52,11 @@ func init() {
|
||||
sdktr.Register(fOpenAI, fMyProv,
|
||||
func(model string, raw []byte, stream bool) []byte { return raw },
|
||||
sdktr.ResponseTransform{
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
||||
return []string{string(raw)}
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
|
||||
return [][]byte{raw}
|
||||
},
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
||||
return string(raw)
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
|
||||
return raw
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -91,8 +91,8 @@ require (
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
@@ -66,8 +68,10 @@ type callbackForwarder struct {
|
||||
}
|
||||
|
||||
var (
|
||||
callbackForwardersMu sync.Mutex
|
||||
callbackForwarders = make(map[int]*callbackForwarder)
|
||||
callbackForwardersMu sync.Mutex
|
||||
callbackForwarders = make(map[int]*callbackForwarder)
|
||||
errAuthFileMustBeJSON = errors.New("auth file must be .json")
|
||||
errAuthFileNotFound = errors.New("auth file not found")
|
||||
)
|
||||
|
||||
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
|
||||
@@ -341,6 +345,21 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
emailValue := gjson.GetBytes(data, "email").String()
|
||||
fileData["type"] = typeValue
|
||||
fileData["email"] = emailValue
|
||||
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
|
||||
switch pv.Type {
|
||||
case gjson.Number:
|
||||
fileData["priority"] = int(pv.Int())
|
||||
case gjson.String:
|
||||
if parsed, errAtoi := strconv.Atoi(strings.TrimSpace(pv.String())); errAtoi == nil {
|
||||
fileData["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if nv := gjson.GetBytes(data, "note"); nv.Exists() && nv.Type == gjson.String {
|
||||
if trimmed := strings.TrimSpace(nv.String()); trimmed != "" {
|
||||
fileData["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, fileData)
|
||||
@@ -424,6 +443,37 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||
entry["id_token"] = claims
|
||||
}
|
||||
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
} else if auth.Metadata != nil {
|
||||
if rawPriority, ok := auth.Metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
entry["priority"] = int(v)
|
||||
case int:
|
||||
entry["priority"] = v
|
||||
case string:
|
||||
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Expose note from Attributes (set by synthesizer from JSON "note" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
|
||||
entry["note"] = note
|
||||
} else if auth.Metadata != nil {
|
||||
if rawNote, ok := auth.Metadata["note"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
|
||||
entry["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
@@ -501,10 +551,23 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
|
||||
}
|
||||
|
||||
func isUnsafeAuthFileName(name string) bool {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return true
|
||||
}
|
||||
if strings.ContainsAny(name, "/\\") {
|
||||
return true
|
||||
}
|
||||
if filepath.VolumeName(name) != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Download single auth file by name
|
||||
func (h *Handler) DownloadAuthFile(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
||||
name := strings.TrimSpace(c.Query("name"))
|
||||
if isUnsafeAuthFileName(name) {
|
||||
c.JSON(400, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
@@ -533,36 +596,61 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
if file, err := c.FormFile("file"); err == nil && file != nil {
|
||||
name := filepath.Base(file.Filename)
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
c.JSON(400, gin.H{"error": "file must be .json"})
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(h.cfg.AuthDir, name)
|
||||
if !filepath.IsAbs(dst) {
|
||||
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
||||
dst = abs
|
||||
}
|
||||
}
|
||||
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
|
||||
return
|
||||
}
|
||||
data, errRead := os.ReadFile(dst)
|
||||
if errRead != nil {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
|
||||
return
|
||||
}
|
||||
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
|
||||
c.JSON(500, gin.H{"error": errReg.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
|
||||
fileHeaders, errMultipart := h.multipartAuthFileHeaders(c)
|
||||
if errMultipart != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid multipart form: %v", errMultipart)})
|
||||
return
|
||||
}
|
||||
name := c.Query("name")
|
||||
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
||||
if len(fileHeaders) == 1 {
|
||||
if _, errUpload := h.storeUploadedAuthFile(ctx, fileHeaders[0]); errUpload != nil {
|
||||
if errors.Is(errUpload, errAuthFileMustBeJSON) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file must be .json"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errUpload.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if len(fileHeaders) > 1 {
|
||||
uploaded := make([]string, 0, len(fileHeaders))
|
||||
failed := make([]gin.H, 0)
|
||||
for _, file := range fileHeaders {
|
||||
name, errUpload := h.storeUploadedAuthFile(ctx, file)
|
||||
if errUpload != nil {
|
||||
failureName := ""
|
||||
if file != nil {
|
||||
failureName = filepath.Base(file.Filename)
|
||||
}
|
||||
msg := errUpload.Error()
|
||||
if errors.Is(errUpload, errAuthFileMustBeJSON) {
|
||||
msg = "file must be .json"
|
||||
}
|
||||
failed = append(failed, gin.H{"name": failureName, "error": msg})
|
||||
continue
|
||||
}
|
||||
uploaded = append(uploaded, name)
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
c.JSON(http.StatusMultiStatus, gin.H{
|
||||
"status": "partial",
|
||||
"uploaded": len(uploaded),
|
||||
"files": uploaded,
|
||||
"failed": failed,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "uploaded": len(uploaded), "files": uploaded})
|
||||
return
|
||||
}
|
||||
if c.ContentType() == "multipart/form-data" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
|
||||
return
|
||||
}
|
||||
name := strings.TrimSpace(c.Query("name"))
|
||||
if isUnsafeAuthFileName(name) {
|
||||
c.JSON(400, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
@@ -575,17 +663,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
if !filepath.IsAbs(dst) {
|
||||
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
||||
dst = abs
|
||||
}
|
||||
}
|
||||
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
|
||||
return
|
||||
}
|
||||
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
|
||||
if err = h.writeAuthFile(ctx, filepath.Base(name), data); err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -632,11 +710,182 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
|
||||
return
|
||||
}
|
||||
name := c.Query("name")
|
||||
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
||||
|
||||
names, errNames := requestedAuthFileNamesForDelete(c)
|
||||
if errNames != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errNames.Error()})
|
||||
return
|
||||
}
|
||||
if len(names) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
if len(names) == 1 {
|
||||
if _, status, errDelete := h.deleteAuthFileByName(ctx, names[0]); errDelete != nil {
|
||||
c.JSON(status, gin.H{"error": errDelete.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
|
||||
deletedFiles := make([]string, 0, len(names))
|
||||
failed := make([]gin.H, 0)
|
||||
for _, name := range names {
|
||||
deletedName, _, errDelete := h.deleteAuthFileByName(ctx, name)
|
||||
if errDelete != nil {
|
||||
failed = append(failed, gin.H{"name": name, "error": errDelete.Error()})
|
||||
continue
|
||||
}
|
||||
deletedFiles = append(deletedFiles, deletedName)
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
c.JSON(http.StatusMultiStatus, gin.H{
|
||||
"status": "partial",
|
||||
"deleted": len(deletedFiles),
|
||||
"files": deletedFiles,
|
||||
"failed": failed,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "deleted": len(deletedFiles), "files": deletedFiles})
|
||||
}
|
||||
|
||||
func (h *Handler) multipartAuthFileHeaders(c *gin.Context) ([]*multipart.FileHeader, error) {
|
||||
if h == nil || c == nil || c.ContentType() != "multipart/form-data" {
|
||||
return nil, nil
|
||||
}
|
||||
form, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if form == nil || len(form.File) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(form.File))
|
||||
for key := range form.File {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
headers := make([]*multipart.FileHeader, 0)
|
||||
for _, key := range keys {
|
||||
headers = append(headers, form.File[key]...)
|
||||
}
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
func (h *Handler) storeUploadedAuthFile(ctx context.Context, file *multipart.FileHeader) (string, error) {
|
||||
if file == nil {
|
||||
return "", fmt.Errorf("no file uploaded")
|
||||
}
|
||||
name := filepath.Base(strings.TrimSpace(file.Filename))
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
return "", errAuthFileMustBeJSON
|
||||
}
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
data, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read uploaded file: %w", err)
|
||||
}
|
||||
if err := h.writeAuthFile(ctx, name, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func (h *Handler) writeAuthFile(ctx context.Context, name string, data []byte) error {
|
||||
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
if !filepath.IsAbs(dst) {
|
||||
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
||||
dst = abs
|
||||
}
|
||||
}
|
||||
auth, err := h.buildAuthFromFileData(dst, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
|
||||
return fmt.Errorf("failed to write file: %w", errWrite)
|
||||
}
|
||||
if err := h.upsertAuthRecord(ctx, auth); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requestedAuthFileNamesForDelete(c *gin.Context) ([]string, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
names := uniqueAuthFileNames(c.QueryArray("name"))
|
||||
if len(names) > 0 {
|
||||
return names, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body")
|
||||
}
|
||||
body = bytes.TrimSpace(body)
|
||||
if len(body) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var objectBody struct {
|
||||
Name string `json:"name"`
|
||||
Names []string `json:"names"`
|
||||
}
|
||||
if body[0] == '[' {
|
||||
var arrayBody []string
|
||||
if err := json.Unmarshal(body, &arrayBody); err != nil {
|
||||
return nil, fmt.Errorf("invalid request body")
|
||||
}
|
||||
return uniqueAuthFileNames(arrayBody), nil
|
||||
}
|
||||
if err := json.Unmarshal(body, &objectBody); err != nil {
|
||||
return nil, fmt.Errorf("invalid request body")
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(objectBody.Names)+1)
|
||||
if strings.TrimSpace(objectBody.Name) != "" {
|
||||
out = append(out, objectBody.Name)
|
||||
}
|
||||
out = append(out, objectBody.Names...)
|
||||
return uniqueAuthFileNames(out), nil
|
||||
}
|
||||
|
||||
func uniqueAuthFileNames(names []string) []string {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(names))
|
||||
out := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if isUnsafeAuthFileName(name) {
|
||||
return "", http.StatusBadRequest, fmt.Errorf("invalid name")
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
targetID := ""
|
||||
@@ -653,22 +902,19 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
}
|
||||
if errRemove := os.Remove(targetPath); errRemove != nil {
|
||||
if os.IsNotExist(errRemove) {
|
||||
c.JSON(404, gin.H{"error": "file not found"})
|
||||
} else {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
|
||||
return filepath.Base(name), http.StatusNotFound, errAuthFileNotFound
|
||||
}
|
||||
return
|
||||
return filepath.Base(name), http.StatusInternalServerError, fmt.Errorf("failed to remove file: %w", errRemove)
|
||||
}
|
||||
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
|
||||
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
|
||||
return
|
||||
return filepath.Base(name), http.StatusInternalServerError, errDeleteRecord
|
||||
}
|
||||
if targetID != "" {
|
||||
h.disableAuth(ctx, targetID)
|
||||
} else {
|
||||
h.disableAuth(ctx, targetPath)
|
||||
}
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
return filepath.Base(name), http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
|
||||
@@ -702,10 +948,25 @@ func (h *Handler) authIDForPath(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
path = filepath.Clean(path)
|
||||
if !filepath.IsAbs(path) {
|
||||
if abs, errAbs := filepath.Abs(path); errAbs == nil {
|
||||
path = abs
|
||||
}
|
||||
}
|
||||
id := path
|
||||
if h != nil && h.cfg != nil {
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if resolvedAuthDir, errResolve := util.ResolveAuthDir(authDir); errResolve == nil && resolvedAuthDir != "" {
|
||||
authDir = resolvedAuthDir
|
||||
}
|
||||
if authDir != "" {
|
||||
authDir = filepath.Clean(authDir)
|
||||
if !filepath.IsAbs(authDir) {
|
||||
if abs, errAbs := filepath.Abs(authDir); errAbs == nil {
|
||||
authDir = abs
|
||||
}
|
||||
}
|
||||
if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
@@ -722,19 +983,27 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
if h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
auth, err := h.buildAuthFromFileData(path, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.upsertAuthRecord(ctx, auth)
|
||||
}
|
||||
|
||||
func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Auth, error) {
|
||||
if path == "" {
|
||||
return fmt.Errorf("auth path is empty")
|
||||
return nil, fmt.Errorf("auth path is empty")
|
||||
}
|
||||
if data == nil {
|
||||
var err error
|
||||
data, err = os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read auth file: %w", err)
|
||||
return nil, fmt.Errorf("failed to read auth file: %w", err)
|
||||
}
|
||||
}
|
||||
metadata := make(map[string]any)
|
||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||
return fmt.Errorf("invalid auth file: %w", err)
|
||||
return nil, fmt.Errorf("invalid auth file: %w", err)
|
||||
}
|
||||
provider, _ := metadata["type"].(string)
|
||||
if provider == "" {
|
||||
@@ -768,13 +1037,25 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
if hasLastRefresh {
|
||||
auth.LastRefreshedAt = lastRefresh
|
||||
}
|
||||
if existing, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
if !hasLastRefresh {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
if h != nil && h.authManager != nil {
|
||||
if existing, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
if !hasLastRefresh {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
}
|
||||
auth.NextRefreshAfter = existing.NextRefreshAfter
|
||||
auth.Runtime = existing.Runtime
|
||||
}
|
||||
auth.NextRefreshAfter = existing.NextRefreshAfter
|
||||
auth.Runtime = existing.Runtime
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (h *Handler) upsertAuthRecord(ctx context.Context, auth *coreauth.Auth) error {
|
||||
if h == nil || h.authManager == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
if existing, ok := h.authManager.GetByID(auth.ID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
_, err := h.authManager.Update(ctx, auth)
|
||||
return err
|
||||
}
|
||||
@@ -848,7 +1129,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
@@ -860,6 +1141,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
@@ -902,14 +1184,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if req.Priority != nil || req.Note != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
if targetAuth.Attributes == nil {
|
||||
targetAuth.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
delete(targetAuth.Attributes, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
|
||||
}
|
||||
}
|
||||
if req.Note != nil {
|
||||
trimmedNote := strings.TrimSpace(*req.Note)
|
||||
if trimmedNote == "" {
|
||||
delete(targetAuth.Metadata, "note")
|
||||
delete(targetAuth.Attributes, "note")
|
||||
} else {
|
||||
targetAuth.Metadata["note"] = trimmedNote
|
||||
targetAuth.Attributes["note"] = trimmedNote
|
||||
}
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
@@ -1838,9 +2138,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata["oauth_client_id"] = clientID
|
||||
if clientSecret != "" {
|
||||
metadata["oauth_client_secret"] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
@@ -3408,3 +3705,84 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
"verification_uri": resp.VerificationURL,
|
||||
})
|
||||
}
|
||||
|
||||
// RequestCursorToken initiates the Cursor PKCE authentication flow.
|
||||
// Supports multiple accounts via ?label=xxx query parameter.
|
||||
// The user opens the returned URL in a browser, logs in, and the server polls
|
||||
// until the authentication completes.
|
||||
func (h *Handler) RequestCursorToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
label := strings.TrimSpace(c.Query("label"))
|
||||
log.Infof("Initializing Cursor authentication (label=%q)...", label)
|
||||
|
||||
authParams, err := cursorauth.GenerateAuthParams()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate Cursor auth params: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate auth params"})
|
||||
return
|
||||
}
|
||||
|
||||
state := fmt.Sprintf("cur-%d", time.Now().UnixNano())
|
||||
RegisterOAuthSession(state, "cursor")
|
||||
|
||||
go func() {
|
||||
log.Info("Waiting for Cursor authentication...")
|
||||
log.Infof("Open this URL in your browser: %s", authParams.LoginURL)
|
||||
|
||||
tokens, errPoll := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
|
||||
if errPoll != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed: "+errPoll.Error())
|
||||
log.Errorf("Cursor authentication failed: %v", errPoll)
|
||||
return
|
||||
}
|
||||
|
||||
// Build metadata
|
||||
metadata := map[string]any{
|
||||
"type": "cursor",
|
||||
"access_token": tokens.AccessToken,
|
||||
"refresh_token": tokens.RefreshToken,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// Extract expiry and account identity from JWT
|
||||
expiry := cursorauth.GetTokenExpiry(tokens.AccessToken)
|
||||
if !expiry.IsZero() {
|
||||
metadata["expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Auto-identify account from JWT sub claim for multi-account support
|
||||
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
|
||||
subHash := cursorauth.SubToShortHash(sub)
|
||||
if sub != "" {
|
||||
metadata["sub"] = sub
|
||||
}
|
||||
|
||||
fileName := cursorauth.CredentialFileName(label, subHash)
|
||||
displayLabel := cursorauth.DisplayLabel(label, subHash)
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "cursor",
|
||||
FileName: fileName,
|
||||
Label: displayLabel,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save Cursor tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save tokens")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Cursor authentication successful! Token saved to %s", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("cursor")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"url": authParams.LoginURL,
|
||||
"state": state,
|
||||
})
|
||||
}
|
||||
|
||||
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
|
||||
files := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
|
||||
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
for _, file := range files {
|
||||
part, err := writer.CreateFormFile("file", file.name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create multipart file: %v", err)
|
||||
}
|
||||
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||
t.Fatalf("failed to write multipart content: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("failed to close multipart writer: %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
ctx.Request = req
|
||||
|
||||
h.UploadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
|
||||
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
fullPath := filepath.Join(authDir, file.name)
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
|
||||
}
|
||||
if string(data) != file.content {
|
||||
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
auths := manager.List()
|
||||
if len(auths) != len(files) {
|
||||
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
|
||||
existingName := "alpha.json"
|
||||
existingContent := `{"type":"codex","email":"alpha@example.com"}`
|
||||
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
|
||||
t.Fatalf("failed to seed existing auth file: %v", err)
|
||||
}
|
||||
|
||||
files := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{name: existingName, content: `{"type":"codex"`},
|
||||
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
for _, file := range files {
|
||||
part, err := writer.CreateFormFile("file", file.name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create multipart file: %v", err)
|
||||
}
|
||||
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||
t.Fatalf("failed to write multipart content: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("failed to close multipart writer: %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
ctx.Request = req
|
||||
|
||||
h.UploadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusMultiStatus {
|
||||
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(authDir, existingName))
|
||||
if err != nil {
|
||||
t.Fatalf("expected existing auth file to remain readable: %v", err)
|
||||
}
|
||||
if string(data) != existingContent {
|
||||
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
|
||||
}
|
||||
|
||||
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("expected valid auth file to be created: %v", err)
|
||||
}
|
||||
if string(betaData) != files[1].content {
|
||||
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
files := []string{"alpha.json", "beta.json"}
|
||||
for _, name := range files {
|
||||
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
|
||||
t.Fatalf("failed to write auth file %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(
|
||||
http.MethodDelete,
|
||||
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
|
||||
nil,
|
||||
)
|
||||
ctx.Request = req
|
||||
|
||||
h.DeleteAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
|
||||
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
|
||||
}
|
||||
|
||||
for _, name := range files {
|
||||
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
fileName := "download-user.json"
|
||||
expected := []byte(`{"type":"codex"}`)
|
||||
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
|
||||
h.DownloadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := rec.Body.Bytes(); string(got) != string(expected) {
|
||||
t.Fatalf("unexpected download content: %q", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
|
||||
|
||||
for _, name := range []string{
|
||||
"../external/secret.json",
|
||||
`..\\external\\secret.json`,
|
||||
"nested/secret.json",
|
||||
`nested\\secret.json`,
|
||||
} {
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
|
||||
h.DownloadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
//go:build windows
|
||||
|
||||
package management
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
authDir := filepath.Join(tempDir, "auth")
|
||||
externalDir := filepath.Join(tempDir, "external")
|
||||
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(externalDir, 0o700); err != nil {
|
||||
t.Fatalf("failed to create external dir: %v", err)
|
||||
}
|
||||
|
||||
secretName := "secret.json"
|
||||
secretPath := filepath.Join(externalDir, secretName)
|
||||
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
|
||||
t.Fatalf("failed to write external file: %v", err)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
|
||||
nil,
|
||||
)
|
||||
h.DownloadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -682,6 +682,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
|
||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
|
||||
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package codebuddy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
BaseURL = "https://copilot.tencent.com"
|
||||
DefaultDomain = "www.codebuddy.cn"
|
||||
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
|
||||
|
||||
codeBuddyStatePath = "/v2/plugin/auth/state"
|
||||
codeBuddyTokenPath = "/v2/plugin/auth/token"
|
||||
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
|
||||
pollInterval = 5 * time.Second
|
||||
maxPollDuration = 5 * time.Minute
|
||||
codeLoginPending = 11217
|
||||
codeSuccess = 0
|
||||
)
|
||||
|
||||
type CodeBuddyAuth struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
|
||||
}
|
||||
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
|
||||
}
|
||||
|
||||
// AuthState holds the state and auth URL returned by the auth state API.
|
||||
type AuthState struct {
|
||||
State string
|
||||
AuthURL string
|
||||
}
|
||||
|
||||
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
|
||||
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
|
||||
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
|
||||
body := []byte("{}")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
|
||||
}
|
||||
|
||||
requestID := uuid.NewString()
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("X-Domain", "copilot.tencent.com")
|
||||
req.Header.Set("X-No-Authorization", "true")
|
||||
req.Header.Set("X-No-User-Id", "true")
|
||||
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||
req.Header.Set("X-No-Department-Info", "true")
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("X-Request-ID", requestID)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy auth state: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
State string `json:"state"`
|
||||
AuthURL string `json:"authUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
|
||||
}
|
||||
if result.Code != codeSuccess {
|
||||
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
|
||||
}
|
||||
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
|
||||
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
|
||||
}
|
||||
|
||||
return &AuthState{
|
||||
State: result.Data.State,
|
||||
AuthURL: result.Data.AuthURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type pollResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
RequestID string `json:"requestId"`
|
||||
Data *struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresIn int64 `json:"expiresIn"`
|
||||
TokenType string `json:"tokenType"`
|
||||
Domain string `json:"domain"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// doPollRequest performs a single polling request, safely reading and closing the response body
|
||||
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
|
||||
}
|
||||
a.applyPollHeaders(req)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy poll: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls until the user completes browser authorization and returns auth data.
|
||||
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
}
|
||||
|
||||
body, statusCode, err := a.doPollRequest(ctx, pollURL)
|
||||
if err != nil {
|
||||
log.Debugf("codebuddy poll: request error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
var result pollResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch result.Code {
|
||||
case codeSuccess:
|
||||
if result.Data == nil {
|
||||
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
|
||||
}
|
||||
userID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||
return &CodeBuddyTokenStorage{
|
||||
AccessToken: result.Data.AccessToken,
|
||||
RefreshToken: result.Data.RefreshToken,
|
||||
ExpiresIn: result.Data.ExpiresIn,
|
||||
TokenType: result.Data.TokenType,
|
||||
Domain: result.Data.Domain,
|
||||
UserID: userID,
|
||||
Type: "codebuddy",
|
||||
}, nil
|
||||
case codeLoginPending:
|
||||
// continue polling
|
||||
default:
|
||||
// TODO: when the CodeBuddy API error code for user denial is known,
|
||||
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
|
||||
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
|
||||
}
|
||||
}
|
||||
return nil, ErrPollingTimeout
|
||||
}
|
||||
|
||||
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
|
||||
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
return "", ErrJWTDecodeFailed
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||
}
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||
}
|
||||
if claims.Sub == "" {
|
||||
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
|
||||
}
|
||||
return claims.Sub, nil
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
|
||||
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
|
||||
if domain == "" {
|
||||
domain = DefaultDomain
|
||||
}
|
||||
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
|
||||
body := []byte("{}")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("X-Domain", domain)
|
||||
req.Header.Set("X-Refresh-Token", refreshToken)
|
||||
req.Header.Set("X-Auth-Refresh-Source", "plugin")
|
||||
req.Header.Set("X-Request-ID", requestID)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("X-User-Id", userID)
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy refresh: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresIn int64 `json:"expiresIn"`
|
||||
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
|
||||
TokenType string `json:"tokenType"`
|
||||
Domain string `json:"domain"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
|
||||
}
|
||||
if result.Code != codeSuccess {
|
||||
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
|
||||
}
|
||||
if result.Data == nil {
|
||||
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
|
||||
}
|
||||
|
||||
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||
if newUserID == "" {
|
||||
newUserID = userID
|
||||
}
|
||||
tokenDomain := result.Data.Domain
|
||||
if tokenDomain == "" {
|
||||
tokenDomain = domain
|
||||
}
|
||||
|
||||
return &CodeBuddyTokenStorage{
|
||||
AccessToken: result.Data.AccessToken,
|
||||
RefreshToken: result.Data.RefreshToken,
|
||||
ExpiresIn: result.Data.ExpiresIn,
|
||||
RefreshExpiresIn: result.Data.RefreshExpiresIn,
|
||||
TokenType: result.Data.TokenType,
|
||||
Domain: tokenDomain,
|
||||
UserID: newUserID,
|
||||
Type: "codebuddy",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("X-No-Authorization", "true")
|
||||
req.Header.Set("X-No-User-Id", "true")
|
||||
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||
req.Header.Set("X-No-Department-Info", "true")
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
}
|
||||
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package codebuddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
|
||||
func newTestAuth(serverURL string) *CodeBuddyAuth {
|
||||
return &CodeBuddyAuth{
|
||||
httpClient: http.DefaultClient,
|
||||
baseURL: serverURL,
|
||||
}
|
||||
}
|
||||
|
||||
// fakeJWT builds a minimal JWT with the given sub claim for testing.
|
||||
func fakeJWT(sub string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
|
||||
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
return header + "." + encodedPayload + ".sig"
|
||||
}
|
||||
|
||||
// --- FetchAuthState tests ---
|
||||
|
||||
func TestFetchAuthState_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Path; got != codeBuddyStatePath {
|
||||
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
|
||||
}
|
||||
if got := r.URL.Query().Get("platform"); got != "CLI" {
|
||||
t.Errorf("expected platform=CLI, got %s", got)
|
||||
}
|
||||
if got := r.Header.Get("User-Agent"); got != UserAgent {
|
||||
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"state": "test-state-abc",
|
||||
"authUrl": "https://example.com/login?state=test-state-abc",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
result, err := auth.FetchAuthState(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.State != "test-state-abc" {
|
||||
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
|
||||
}
|
||||
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
|
||||
t.Errorf("unexpected authURL: %s", result.AuthURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAuthState_NonOKStatus(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("internal error"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.FetchAuthState(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-200 status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAuthState_APIErrorCode(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 10001,
|
||||
"msg": "rate limited",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.FetchAuthState(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-zero code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAuthState_MissingData(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"state": "",
|
||||
"authUrl": "",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.FetchAuthState(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty state/authUrl")
|
||||
}
|
||||
}
|
||||
|
||||
// --- RefreshToken tests ---
|
||||
|
||||
func TestRefreshToken_Success(t *testing.T) {
|
||||
newAccessToken := fakeJWT("refreshed-user-456")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Path; got != codeBuddyRefreshPath {
|
||||
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
|
||||
}
|
||||
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
|
||||
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
|
||||
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
|
||||
}
|
||||
if got := r.Header.Get("X-User-Id"); got != "user-123" {
|
||||
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
|
||||
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"accessToken": newAccessToken,
|
||||
"refreshToken": "new-refresh-token",
|
||||
"expiresIn": 3600,
|
||||
"refreshExpiresIn": 86400,
|
||||
"tokenType": "bearer",
|
||||
"domain": "custom.domain.com",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if storage.AccessToken != newAccessToken {
|
||||
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
|
||||
}
|
||||
if storage.RefreshToken != "new-refresh-token" {
|
||||
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
|
||||
}
|
||||
if storage.UserID != "refreshed-user-456" {
|
||||
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
|
||||
}
|
||||
if storage.ExpiresIn != 3600 {
|
||||
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
|
||||
}
|
||||
if storage.RefreshExpiresIn != 86400 {
|
||||
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
|
||||
}
|
||||
if storage.Domain != "custom.domain.com" {
|
||||
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
|
||||
}
|
||||
if storage.Type != "codebuddy" {
|
||||
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_DefaultDomain(t *testing.T) {
|
||||
var receivedDomain string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedDomain = r.Header.Get("X-Domain")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"accessToken": fakeJWT("user-1"),
|
||||
"refreshToken": "rt",
|
||||
"expiresIn": 3600,
|
||||
"tokenType": "bearer",
|
||||
"domain": DefaultDomain,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if receivedDomain != DefaultDomain {
|
||||
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_Unauthorized(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_Forbidden(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_APIErrorCode(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 40001,
|
||||
"msg": "invalid refresh token",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-zero API code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
|
||||
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
|
||||
// When the response domain is empty, it should fall back to the request domain.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"accessToken": "not-a-valid-jwt",
|
||||
"refreshToken": "new-rt",
|
||||
"expiresIn": 7200,
|
||||
"tokenType": "bearer",
|
||||
"domain": "",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if storage.UserID != "original-uid" {
|
||||
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
|
||||
}
|
||||
if storage.Domain != "original.domain.com" {
|
||||
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
|
||||
}
|
||||
}
|
||||
22
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
22
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package codebuddy_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
)
|
||||
|
||||
func TestDecodeUserID_ValidJWT(t *testing.T) {
|
||||
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
|
||||
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
|
||||
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
|
||||
auth := codebuddy.NewCodeBuddyAuth(nil)
|
||||
userID, err := auth.DecodeUserID(token)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if userID != "test-user-id-123" {
|
||||
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
|
||||
}
|
||||
}
|
||||
|
||||
25
internal/auth/codebuddy/errors.go
Normal file
25
internal/auth/codebuddy/errors.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package codebuddy
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
|
||||
ErrAccessDenied = errors.New("codebuddy: access denied by user")
|
||||
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
|
||||
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
|
||||
)
|
||||
|
||||
func GetUserFriendlyMessage(err error) string {
|
||||
switch {
|
||||
case errors.Is(err, ErrPollingTimeout):
|
||||
return "Authentication timed out. Please try again."
|
||||
case errors.Is(err, ErrAccessDenied):
|
||||
return "Access denied. Please try again and approve the login request."
|
||||
case errors.Is(err, ErrJWTDecodeFailed):
|
||||
return "Failed to decode token. Please try logging in again."
|
||||
case errors.Is(err, ErrTokenFetchFailed):
|
||||
return "Failed to fetch token from server. Please try again."
|
||||
default:
|
||||
return "Authentication failed: " + err.Error()
|
||||
}
|
||||
}
|
||||
65
internal/auth/codebuddy/token.go
Normal file
65
internal/auth/codebuddy/token.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Package codebuddy provides authentication and token management functionality
|
||||
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
|
||||
package codebuddy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
|
||||
// for managing access tokens and user account information.
|
||||
type CodeBuddyTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// ExpiresIn is the number of seconds until the access token expires.
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
// RefreshExpiresIn is the number of seconds until the refresh token expires.
|
||||
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
|
||||
// TokenType is the type of token, typically "bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Domain is the CodeBuddy service domain/region.
|
||||
Domain string `json:"domain"`
|
||||
// UserID is the user ID associated with this token.
|
||||
UserID string `json:"user_id"`
|
||||
// Type indicates the authentication provider type, always "codebuddy" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
s.Type = "codebuddy"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(s); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
33
internal/auth/cursor/filename.go
Normal file
33
internal/auth/cursor/filename.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Cursor credentials.
|
||||
// Priority: explicit label > auto-generated from JWT sub hash.
|
||||
// If both label and subHash are empty, falls back to "cursor.json".
|
||||
func CredentialFileName(label, subHash string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
subHash = strings.TrimSpace(subHash)
|
||||
if label != "" {
|
||||
return fmt.Sprintf("cursor.%s.json", label)
|
||||
}
|
||||
if subHash != "" {
|
||||
return fmt.Sprintf("cursor.%s.json", subHash)
|
||||
}
|
||||
return "cursor.json"
|
||||
}
|
||||
|
||||
// DisplayLabel returns a human-readable label for the Cursor account.
|
||||
func DisplayLabel(label, subHash string) string {
|
||||
label = strings.TrimSpace(label)
|
||||
if label != "" {
|
||||
return "Cursor " + label
|
||||
}
|
||||
if subHash != "" {
|
||||
return "Cursor " + subHash
|
||||
}
|
||||
return "Cursor User"
|
||||
}
|
||||
249
internal/auth/cursor/oauth.go
Normal file
249
internal/auth/cursor/oauth.go
Normal file
@@ -0,0 +1,249 @@
|
||||
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CursorLoginURL = "https://cursor.com/loginDeepControl"
|
||||
CursorPollURL = "https://api2.cursor.sh/auth/poll"
|
||||
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
|
||||
|
||||
pollMaxAttempts = 150
|
||||
pollBaseDelay = 1 * time.Second
|
||||
pollMaxDelay = 10 * time.Second
|
||||
pollBackoffMultiply = 1.2
|
||||
maxConsecutiveErrors = 10
|
||||
)
|
||||
|
||||
// AuthParams holds the PKCE parameters for Cursor login.
|
||||
type AuthParams struct {
|
||||
Verifier string
|
||||
Challenge string
|
||||
UUID string
|
||||
LoginURL string
|
||||
}
|
||||
|
||||
// TokenPair holds the access and refresh tokens from Cursor.
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// GeneratePKCE creates a PKCE verifier and challenge pair.
|
||||
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||
verifierBytes := make([]byte, 96)
|
||||
if _, err = rand.Read(verifierBytes); err != nil {
|
||||
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// GenerateAuthParams creates the full set of auth params for Cursor login.
|
||||
func GenerateAuthParams() (*AuthParams, error) {
|
||||
verifier, challenge, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uuidBytes := make([]byte, 16)
|
||||
if _, err = rand.Read(uuidBytes); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
|
||||
}
|
||||
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
|
||||
|
||||
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
|
||||
CursorLoginURL, challenge, uuid)
|
||||
|
||||
return &AuthParams{
|
||||
Verifier: verifier,
|
||||
Challenge: challenge,
|
||||
UUID: uuid,
|
||||
LoginURL: loginURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PollForAuth polls the Cursor auth endpoint until the user completes login.
|
||||
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
|
||||
delay := pollBaseDelay
|
||||
consecutiveErrors := 0
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= maxConsecutiveErrors {
|
||||
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
|
||||
}
|
||||
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// Still waiting for user to authorize
|
||||
consecutiveErrors = 0
|
||||
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
var tokens TokenPair
|
||||
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
|
||||
}
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
|
||||
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
|
||||
}
|
||||
|
||||
// RefreshToken refreshes a Cursor access token using the refresh token.
|
||||
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
|
||||
strings.NewReader("{}"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+refreshToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokens TokenPair
|
||||
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Keep original refresh token if not returned
|
||||
if tokens.RefreshToken == "" {
|
||||
tokens.RefreshToken = refreshToken
|
||||
}
|
||||
|
||||
return &tokens, nil
|
||||
}
|
||||
|
||||
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
|
||||
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
|
||||
// the account. Returns empty string if parsing fails.
|
||||
func ParseJWTSub(token string) string {
|
||||
decoded := decodeJWTPayload(token)
|
||||
if decoded == nil {
|
||||
return ""
|
||||
}
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
|
||||
// e.g. "auth0|user_2x..." → "a3f8b2c1"
|
||||
func SubToShortHash(sub string) string {
|
||||
if sub == "" {
|
||||
return ""
|
||||
}
|
||||
h := sha256.Sum256([]byte(sub))
|
||||
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
|
||||
}
|
||||
|
||||
// decodeJWTPayload decodes the payload (middle) part of a JWT.
|
||||
func decodeJWTPayload(token string) []byte {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil
|
||||
}
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
payload = strings.ReplaceAll(payload, "-", "+")
|
||||
payload = strings.ReplaceAll(payload, "_", "/")
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
|
||||
// Falls back to 1 hour from now if the token can't be parsed.
|
||||
func GetTokenExpiry(token string) time.Time {
|
||||
decoded := decodeJWTPayload(token)
|
||||
if decoded == nil {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Exp float64 `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
sec, frac := math.Modf(claims.Exp)
|
||||
expiry := time.Unix(int64(sec), int64(frac*1e9))
|
||||
// Subtract 5-minute safety margin
|
||||
return expiry.Add(-5 * time.Minute)
|
||||
}
|
||||
|
||||
func minDuration(a, b time.Duration) time.Duration {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
84
internal/auth/cursor/proto/connect.go
Normal file
84
internal/auth/cursor/proto/connect.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
|
||||
ConnectEndStreamFlag byte = 0x02
|
||||
// ConnectCompressionFlag indicates the payload is compressed (not supported).
|
||||
ConnectCompressionFlag byte = 0x01
|
||||
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
|
||||
ConnectFrameHeaderSize = 5
|
||||
)
|
||||
|
||||
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
|
||||
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
|
||||
func FrameConnectMessage(data []byte, flags byte) []byte {
|
||||
frame := make([]byte, ConnectFrameHeaderSize+len(data))
|
||||
frame[0] = flags
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
|
||||
copy(frame[5:], data)
|
||||
return frame
|
||||
}
|
||||
|
||||
// ParseConnectFrame extracts one frame from a buffer.
|
||||
// Returns (flags, payload, bytesConsumed, ok).
|
||||
// ok is false when the buffer is too short for a complete frame.
|
||||
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
|
||||
if len(buf) < ConnectFrameHeaderSize {
|
||||
return 0, nil, 0, false
|
||||
}
|
||||
flags = buf[0]
|
||||
length := binary.BigEndian.Uint32(buf[1:5])
|
||||
total := ConnectFrameHeaderSize + int(length)
|
||||
if len(buf) < total {
|
||||
return 0, nil, 0, false
|
||||
}
|
||||
return flags, buf[5:total], total, true
|
||||
}
|
||||
|
||||
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
|
||||
// The Code field contains the server-defined error code (e.g. gRPC standard codes
|
||||
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
|
||||
type ConnectError struct {
|
||||
Code string // server-defined error code
|
||||
Message string // human-readable error description
|
||||
}
|
||||
|
||||
func (e *ConnectError) Error() string {
|
||||
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
|
||||
// Returns nil if there is no error in the trailer.
|
||||
// On error, returns a *ConnectError with the server's error code and message.
|
||||
func ParseConnectEndStream(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var trailer struct {
|
||||
Error *struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &trailer); err != nil {
|
||||
return fmt.Errorf("failed to parse Connect end stream: %w", err)
|
||||
}
|
||||
if trailer.Error != nil {
|
||||
code := trailer.Error.Code
|
||||
if code == "" {
|
||||
code = "unknown"
|
||||
}
|
||||
msg := trailer.Error.Message
|
||||
if msg == "" {
|
||||
msg = "Unknown error"
|
||||
}
|
||||
return &ConnectError{Code: code, Message: msg}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
564
internal/auth/cursor/proto/decode.go
Normal file
564
internal/auth/cursor/proto/decode.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
// ServerMessageType identifies the kind of decoded server message.
|
||||
type ServerMessageType int
|
||||
|
||||
const (
|
||||
ServerMsgUnknown ServerMessageType = iota
|
||||
ServerMsgTextDelta // Text content delta
|
||||
ServerMsgThinkingDelta // Thinking/reasoning delta
|
||||
ServerMsgThinkingCompleted // Thinking completed
|
||||
ServerMsgKvGetBlob // Server wants a blob
|
||||
ServerMsgKvSetBlob // Server wants to store a blob
|
||||
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
|
||||
ServerMsgExecMcpArgs // Server wants MCP tool execution
|
||||
ServerMsgExecShellArgs // Rejected: shell command
|
||||
ServerMsgExecReadArgs // Rejected: file read
|
||||
ServerMsgExecWriteArgs // Rejected: file write
|
||||
ServerMsgExecDeleteArgs // Rejected: file delete
|
||||
ServerMsgExecLsArgs // Rejected: directory listing
|
||||
ServerMsgExecGrepArgs // Rejected: grep search
|
||||
ServerMsgExecFetchArgs // Rejected: HTTP fetch
|
||||
ServerMsgExecDiagnostics // Respond with empty diagnostics
|
||||
ServerMsgExecShellStream // Rejected: shell stream
|
||||
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||
ServerMsgExecOther // Other exec types (respond with empty)
|
||||
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||
ServerMsgHeartbeat // Server heartbeat
|
||||
ServerMsgTokenDelta // Token usage delta
|
||||
ServerMsgCheckpoint // Conversation checkpoint update
|
||||
)
|
||||
|
||||
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||
type DecodedServerMessage struct {
|
||||
Type ServerMessageType
|
||||
|
||||
// For text/thinking deltas
|
||||
Text string
|
||||
|
||||
// For KV messages
|
||||
KvId uint32
|
||||
BlobId []byte // hex-encoded blob ID
|
||||
BlobData []byte // for setBlobArgs
|
||||
|
||||
// For exec messages
|
||||
ExecMsgId uint32
|
||||
ExecId string
|
||||
|
||||
// For MCP args
|
||||
McpToolName string
|
||||
McpToolCallId string
|
||||
McpArgs map[string][]byte // arg name -> protobuf-encoded value
|
||||
|
||||
// For rejection context
|
||||
Path string
|
||||
Command string
|
||||
WorkingDirectory string
|
||||
Url string
|
||||
|
||||
// For other exec - the raw field number for building a response
|
||||
ExecFieldNumber int
|
||||
|
||||
// For TokenDeltaUpdate
|
||||
TokenDelta int64
|
||||
|
||||
// For conversation checkpoint update (raw bytes, not decoded)
|
||||
CheckpointData []byte
|
||||
}
|
||||
|
||||
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
||||
// a structured representation of the first meaningful message found.
|
||||
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
|
||||
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
|
||||
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid tag")
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid bytes field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
// Debug: log top-level ASM fields
|
||||
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
|
||||
|
||||
switch num {
|
||||
case ASM_InteractionUpdate:
|
||||
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
|
||||
decodeInteractionUpdate(val, msg)
|
||||
case ASM_ExecServerMessage:
|
||||
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
|
||||
decodeExecServerMessage(val, msg)
|
||||
case ASM_KvServerMessage:
|
||||
decodeKvServerMessage(val, msg)
|
||||
case ASM_ConversationCheckpoint:
|
||||
msg.Type = ServerMsgCheckpoint
|
||||
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
|
||||
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
|
||||
}
|
||||
|
||||
case protowire.VarintType:
|
||||
_, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid varint field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
default:
|
||||
// Skip unknown wire types
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return msg, fmt.Errorf("invalid field %d", num)
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
|
||||
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||
|
||||
switch num {
|
||||
case IU_TextDelta:
|
||||
msg.Type = ServerMsgTextDelta
|
||||
msg.Text = decodeStringField(val, TDU_Text)
|
||||
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
|
||||
case IU_ThinkingDelta:
|
||||
msg.Type = ServerMsgThinkingDelta
|
||||
msg.Text = decodeStringField(val, TKD_Text)
|
||||
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
|
||||
case IU_ThinkingCompleted:
|
||||
msg.Type = ServerMsgThinkingCompleted
|
||||
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
|
||||
case 2:
|
||||
// tool_call_started - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
|
||||
case 3:
|
||||
// tool_call_completed - ignore but log
|
||||
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
||||
case 8:
|
||||
// token_delta - extract token count
|
||||
msg.Type = ServerMsgTokenDelta
|
||||
msg.TokenDelta = decodeVarintField(val, 1)
|
||||
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
|
||||
case 13:
|
||||
// heartbeat from server
|
||||
msg.Type = ServerMsgHeartbeat
|
||||
case 14:
|
||||
// turn_ended - critical: model finished generating
|
||||
msg.Type = ServerMsgTurnEnded
|
||||
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
|
||||
case 16:
|
||||
// step_started - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
|
||||
case 17:
|
||||
// step_completed - ignore
|
||||
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
|
||||
default:
|
||||
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.VarintType:
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == KSM_Id {
|
||||
msg.KvId = uint32(val)
|
||||
}
|
||||
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch num {
|
||||
case KSM_GetBlobArgs:
|
||||
msg.Type = ServerMsgKvGetBlob
|
||||
msg.BlobId = decodeBytesField(val, GBA_BlobId)
|
||||
case KSM_SetBlobArgs:
|
||||
msg.Type = ServerMsgKvSetBlob
|
||||
decodeSetBlobArgs(val, msg)
|
||||
}
|
||||
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
switch num {
|
||||
case SBA_BlobId:
|
||||
msg.BlobId = val
|
||||
case SBA_BlobData:
|
||||
msg.BlobData = val
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch typ {
|
||||
case protowire.VarintType:
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == ESM_Id {
|
||||
msg.ExecMsgId = uint32(val)
|
||||
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
|
||||
}
|
||||
|
||||
case protowire.BytesType:
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
// Debug: log all fields found in ExecServerMessage
|
||||
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||
|
||||
switch num {
|
||||
case ESM_ExecId:
|
||||
msg.ExecId = string(val)
|
||||
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
|
||||
case ESM_RequestContextArgs:
|
||||
msg.Type = ServerMsgExecRequestCtx
|
||||
case ESM_McpArgs:
|
||||
msg.Type = ServerMsgExecMcpArgs
|
||||
decodeMcpArgs(val, msg)
|
||||
case ESM_ShellArgs:
|
||||
msg.Type = ServerMsgExecShellArgs
|
||||
decodeShellArgs(val, msg)
|
||||
case ESM_ShellStreamArgs:
|
||||
msg.Type = ServerMsgExecShellStream
|
||||
decodeShellArgs(val, msg)
|
||||
case ESM_ReadArgs:
|
||||
msg.Type = ServerMsgExecReadArgs
|
||||
msg.Path = decodeStringField(val, RA_Path)
|
||||
case ESM_WriteArgs:
|
||||
msg.Type = ServerMsgExecWriteArgs
|
||||
msg.Path = decodeStringField(val, WA_Path)
|
||||
case ESM_DeleteArgs:
|
||||
msg.Type = ServerMsgExecDeleteArgs
|
||||
msg.Path = decodeStringField(val, DA_Path)
|
||||
case ESM_LsArgs:
|
||||
msg.Type = ServerMsgExecLsArgs
|
||||
msg.Path = decodeStringField(val, LA_Path)
|
||||
case ESM_GrepArgs:
|
||||
msg.Type = ServerMsgExecGrepArgs
|
||||
case ESM_FetchArgs:
|
||||
msg.Type = ServerMsgExecFetchArgs
|
||||
msg.Url = decodeStringField(val, FA_Url)
|
||||
case ESM_DiagnosticsArgs:
|
||||
msg.Type = ServerMsgExecDiagnostics
|
||||
case ESM_BackgroundShellSpawn:
|
||||
msg.Type = ServerMsgExecBgShellSpawn
|
||||
decodeShellArgs(val, msg) // same structure
|
||||
case ESM_WriteShellStdinArgs:
|
||||
msg.Type = ServerMsgExecWriteShellStdin
|
||||
default:
|
||||
// Unknown exec types - only set if we haven't identified the type yet
|
||||
// (other fields like span_context (19) come after the exec type field)
|
||||
if msg.Type == ServerMsgUnknown {
|
||||
msg.Type = ServerMsgExecOther
|
||||
msg.ExecFieldNumber = int(num)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
|
||||
msg.McpArgs = make(map[string][]byte)
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
switch num {
|
||||
case MCA_Name:
|
||||
msg.McpToolName = string(val)
|
||||
case MCA_Args:
|
||||
// Map entries are encoded as submessages with key=1, value=2
|
||||
decodeMapEntry(val, msg.McpArgs)
|
||||
case MCA_ToolCallId:
|
||||
msg.McpToolCallId = string(val)
|
||||
case MCA_ToolName:
|
||||
// ToolName takes precedence if present
|
||||
if msg.McpToolName == "" || string(val) != "" {
|
||||
msg.McpToolName = string(val)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeMapEntry(data []byte, m map[string][]byte) {
|
||||
var key string
|
||||
var value []byte
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
if num == 1 {
|
||||
key = string(val)
|
||||
} else if num == 2 {
|
||||
value = append([]byte(nil), val...)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
if key != "" {
|
||||
m[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
switch num {
|
||||
case SHA_Command:
|
||||
msg.Command = string(val)
|
||||
case SHA_WorkingDirectory:
|
||||
msg.WorkingDirectory = string(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper decoders ---
|
||||
|
||||
// decodeStringField extracts a string from the first matching field in a submessage.
|
||||
func decodeStringField(data []byte, targetField protowire.Number) string {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return string(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// decodeBytesField extracts bytes from the first matching field in a submessage.
|
||||
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
|
||||
if typ == protowire.BytesType {
|
||||
val, n := protowire.ConsumeBytes(data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return append([]byte(nil), val...)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return nil
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||
for len(data) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if typ == protowire.VarintType {
|
||||
val, n := protowire.ConsumeVarint(data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
if num == targetField {
|
||||
return int64(val)
|
||||
}
|
||||
} else {
|
||||
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
||||
func BlobIdHex(blobId []byte) string {
|
||||
return hex.EncodeToString(blobId)
|
||||
}
|
||||
|
||||
1244
internal/auth/cursor/proto/descriptor.go
Normal file
1244
internal/auth/cursor/proto/descriptor.go
Normal file
File diff suppressed because it is too large
Load Diff
664
internal/auth/cursor/proto/encode.go
Normal file
664
internal/auth/cursor/proto/encode.go
Normal file
@@ -0,0 +1,664 @@
|
||||
// Package proto provides protobuf encoding for Cursor's gRPC API,
|
||||
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
|
||||
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
|
||||
package proto
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/dynamicpb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// --- Public types ---
|
||||
|
||||
// RunRequestParams holds all data needed to build an AgentRunRequest.
|
||||
type RunRequestParams struct {
|
||||
ModelId string
|
||||
SystemPrompt string
|
||||
UserText string
|
||||
MessageId string
|
||||
ConversationId string
|
||||
Images []ImageData
|
||||
Turns []TurnData
|
||||
McpTools []McpToolDef
|
||||
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
|
||||
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
MimeType string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type TurnData struct {
|
||||
UserText string
|
||||
AssistantText string
|
||||
}
|
||||
|
||||
type McpToolDef struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema json.RawMessage
|
||||
}
|
||||
|
||||
// --- Helper: create a dynamic message and set fields ---
|
||||
|
||||
func newMsg(name string) *dynamicpb.Message {
|
||||
return dynamicpb.NewMessage(Msg(name))
|
||||
}
|
||||
|
||||
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
|
||||
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
|
||||
}
|
||||
|
||||
func setStr(msg *dynamicpb.Message, name, val string) {
|
||||
if val != "" {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
|
||||
}
|
||||
}
|
||||
|
||||
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
|
||||
if len(val) > 0 {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
|
||||
}
|
||||
}
|
||||
|
||||
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
|
||||
}
|
||||
|
||||
func setBool(msg *dynamicpb.Message, name string, val bool) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
|
||||
}
|
||||
|
||||
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
|
||||
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
|
||||
}
|
||||
|
||||
func marshal(msg *dynamicpb.Message) []byte {
|
||||
b, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
panic("cursor proto marshal: " + err.Error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// --- Encode functions mirroring cursor-fetch.ts ---
|
||||
|
||||
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
|
||||
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
|
||||
func EncodeHeartbeat() []byte {
|
||||
hb := newMsg("ClientHeartbeat")
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "client_heartbeat", hb)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
|
||||
// Mirrors buildCursorRequest() in cursor-fetch.ts.
|
||||
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
|
||||
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
|
||||
func EncodeRunRequest(p *RunRequestParams) []byte {
|
||||
if p.RawCheckpoint != nil {
|
||||
return encodeRunRequestWithCheckpoint(p)
|
||||
}
|
||||
|
||||
if p.BlobStore == nil {
|
||||
p.BlobStore = make(map[string][]byte)
|
||||
}
|
||||
|
||||
// --- Conversation turns ---
|
||||
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
|
||||
var turnBytes [][]byte
|
||||
for _, turn := range p.Turns {
|
||||
// UserMessage for this turn
|
||||
um := newMsg("UserMessage")
|
||||
setStr(um, "text", turn.UserText)
|
||||
setStr(um, "message_id", generateId())
|
||||
umBytes := marshal(um)
|
||||
|
||||
// Steps (assistant response)
|
||||
var stepBytes [][]byte
|
||||
if turn.AssistantText != "" {
|
||||
am := newMsg("AssistantMessage")
|
||||
setStr(am, "text", turn.AssistantText)
|
||||
step := newMsg("ConversationStep")
|
||||
setMsg(step, "assistant_message", am)
|
||||
stepBytes = append(stepBytes, marshal(step))
|
||||
}
|
||||
|
||||
// AgentConversationTurnStructure (fields are bytes, not submessages)
|
||||
agentTurn := newMsg("AgentConversationTurnStructure")
|
||||
setBytes(agentTurn, "user_message", umBytes)
|
||||
for _, sb := range stepBytes {
|
||||
stepsField := field(agentTurn, "steps")
|
||||
list := agentTurn.Mutable(stepsField).List()
|
||||
list.Append(protoreflect.ValueOfBytes(sb))
|
||||
}
|
||||
|
||||
// ConversationTurnStructure (oneof turn → agentConversationTurn)
|
||||
cts := newMsg("ConversationTurnStructure")
|
||||
setMsg(cts, "agent_conversation_turn", agentTurn)
|
||||
turnBytes = append(turnBytes, marshal(cts))
|
||||
}
|
||||
|
||||
// --- System prompt blob ---
|
||||
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
|
||||
blobId := sha256Sum(systemJSON)
|
||||
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
|
||||
|
||||
// --- ConversationStateStructure ---
|
||||
css := newMsg("ConversationStateStructure")
|
||||
// rootPromptMessagesJson: repeated bytes
|
||||
rootField := field(css, "root_prompt_messages_json")
|
||||
rootList := css.Mutable(rootField).List()
|
||||
rootList.Append(protoreflect.ValueOfBytes(blobId))
|
||||
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
|
||||
turnsField := field(css, "turns")
|
||||
turnsList := css.Mutable(turnsField).List()
|
||||
for _, tb := range turnBytes {
|
||||
turnsList.Append(protoreflect.ValueOfBytes(tb))
|
||||
}
|
||||
turnsOldField := field(css, "turns_old")
|
||||
if turnsOldField != nil {
|
||||
turnsOldList := css.Mutable(turnsOldField).List()
|
||||
for _, tb := range turnBytes {
|
||||
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
|
||||
}
|
||||
}
|
||||
|
||||
// --- UserMessage (current) ---
|
||||
userMessage := newMsg("UserMessage")
|
||||
setStr(userMessage, "text", p.UserText)
|
||||
setStr(userMessage, "message_id", p.MessageId)
|
||||
|
||||
// Images via SelectedContext
|
||||
if len(p.Images) > 0 {
|
||||
sc := newMsg("SelectedContext")
|
||||
imgsField := field(sc, "selected_images")
|
||||
imgsList := sc.Mutable(imgsField).List()
|
||||
for _, img := range p.Images {
|
||||
si := newMsg("SelectedImage")
|
||||
setStr(si, "uuid", generateId())
|
||||
setStr(si, "mime_type", img.MimeType)
|
||||
setBytes(si, "data", img.Data)
|
||||
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||
}
|
||||
setMsg(userMessage, "selected_context", sc)
|
||||
}
|
||||
|
||||
// --- UserMessageAction ---
|
||||
uma := newMsg("UserMessageAction")
|
||||
setMsg(uma, "user_message", userMessage)
|
||||
|
||||
// --- ConversationAction ---
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "user_message_action", uma)
|
||||
|
||||
// --- ModelDetails ---
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
|
||||
// --- AgentRunRequest ---
|
||||
arr := newMsg("AgentRunRequest")
|
||||
setMsg(arr, "conversation_state", css)
|
||||
setMsg(arr, "action", ca)
|
||||
setMsg(arr, "model_details", md)
|
||||
setStr(arr, "conversation_id", p.ConversationId)
|
||||
|
||||
// McpTools
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
setMsg(arr, "mcp_tools", mcpTools)
|
||||
}
|
||||
|
||||
// --- AgentClientMessage ---
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "run_request", arr)
|
||||
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
|
||||
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
|
||||
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
|
||||
// Build UserMessage
|
||||
userMessage := newMsg("UserMessage")
|
||||
setStr(userMessage, "text", p.UserText)
|
||||
setStr(userMessage, "message_id", p.MessageId)
|
||||
if len(p.Images) > 0 {
|
||||
sc := newMsg("SelectedContext")
|
||||
imgsField := field(sc, "selected_images")
|
||||
imgsList := sc.Mutable(imgsField).List()
|
||||
for _, img := range p.Images {
|
||||
si := newMsg("SelectedImage")
|
||||
setStr(si, "uuid", generateId())
|
||||
setStr(si, "mime_type", img.MimeType)
|
||||
setBytes(si, "data", img.Data)
|
||||
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||
}
|
||||
setMsg(userMessage, "selected_context", sc)
|
||||
}
|
||||
|
||||
// Build ConversationAction with UserMessageAction
|
||||
uma := newMsg("UserMessageAction")
|
||||
setMsg(uma, "user_message", userMessage)
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "user_message_action", uma)
|
||||
caBytes := marshal(ca)
|
||||
|
||||
// Build ModelDetails
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
mdBytes := marshal(md)
|
||||
|
||||
// Build McpTools
|
||||
var mcpToolsBytes []byte
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
mcpToolsBytes = marshal(mcpTools)
|
||||
}
|
||||
|
||||
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
|
||||
var arrBuf []byte
|
||||
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
|
||||
// field 2: action = ConversationAction
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
|
||||
// field 3: model_details = ModelDetails
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
|
||||
// field 4: mcp_tools = McpTools
|
||||
if len(mcpToolsBytes) > 0 {
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
|
||||
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
|
||||
}
|
||||
// field 5: conversation_id = string
|
||||
if p.ConversationId != "" {
|
||||
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
|
||||
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
|
||||
}
|
||||
|
||||
// Wrap in AgentClientMessage field 1 (run_request)
|
||||
var acmBuf []byte
|
||||
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
|
||||
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
|
||||
|
||||
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
|
||||
return acmBuf
|
||||
}
|
||||
|
||||
// ResumeRequestParams holds data for a ResumeAction request.
|
||||
type ResumeRequestParams struct {
|
||||
ModelId string
|
||||
ConversationId string
|
||||
McpTools []McpToolDef
|
||||
}
|
||||
|
||||
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
|
||||
// Used to resume a conversation by conversation_id without re-sending full history.
|
||||
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
|
||||
// RequestContext with tools
|
||||
rc := newMsg("RequestContext")
|
||||
if len(p.McpTools) > 0 {
|
||||
toolsField := field(rc, "tools")
|
||||
toolsList := rc.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
}
|
||||
|
||||
// ResumeAction
|
||||
ra := newMsg("ResumeAction")
|
||||
setMsg(ra, "request_context", rc)
|
||||
|
||||
// ConversationAction with resume_action
|
||||
ca := newMsg("ConversationAction")
|
||||
setMsg(ca, "resume_action", ra)
|
||||
|
||||
// ModelDetails
|
||||
md := newMsg("ModelDetails")
|
||||
setStr(md, "model_id", p.ModelId)
|
||||
setStr(md, "display_model_id", p.ModelId)
|
||||
setStr(md, "display_name", p.ModelId)
|
||||
|
||||
// AgentRunRequest — no conversation_state needed for resume
|
||||
arr := newMsg("AgentRunRequest")
|
||||
setMsg(arr, "action", ca)
|
||||
setMsg(arr, "model_details", md)
|
||||
setStr(arr, "conversation_id", p.ConversationId)
|
||||
|
||||
// McpTools at top level
|
||||
if len(p.McpTools) > 0 {
|
||||
mcpTools := newMsg("McpTools")
|
||||
toolsField := field(mcpTools, "mcp_tools")
|
||||
toolsList := mcpTools.Mutable(toolsField).List()
|
||||
for _, tool := range p.McpTools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
setMsg(arr, "mcp_tools", mcpTools)
|
||||
}
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "run_request", arr)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// --- KV response encoders ---
|
||||
// Mirrors handleKvMessage() in cursor-fetch.ts
|
||||
|
||||
// EncodeKvGetBlobResult responds to a getBlobArgs request.
|
||||
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
|
||||
result := newMsg("GetBlobResult")
|
||||
if blobData != nil {
|
||||
setBytes(result, "blob_data", blobData)
|
||||
}
|
||||
|
||||
kvc := newMsg("KvClientMessage")
|
||||
setUint32(kvc, "id", kvId)
|
||||
setMsg(kvc, "get_blob_result", result)
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "kv_client_message", kvc)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// EncodeKvSetBlobResult responds to a setBlobArgs request.
|
||||
func EncodeKvSetBlobResult(kvId uint32) []byte {
|
||||
result := newMsg("SetBlobResult")
|
||||
|
||||
kvc := newMsg("KvClientMessage")
|
||||
setUint32(kvc, "id", kvId)
|
||||
setMsg(kvc, "set_blob_result", result)
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "kv_client_message", kvc)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
// --- Exec response encoders ---
|
||||
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
|
||||
|
||||
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
|
||||
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
|
||||
// RequestContext with tools
|
||||
rc := newMsg("RequestContext")
|
||||
if len(tools) > 0 {
|
||||
toolsField := field(rc, "tools")
|
||||
toolsList := rc.Mutable(toolsField).List()
|
||||
for _, tool := range tools {
|
||||
td := newMsg("McpToolDefinition")
|
||||
setStr(td, "name", tool.Name)
|
||||
setStr(td, "description", tool.Description)
|
||||
if len(tool.InputSchema) > 0 {
|
||||
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||
}
|
||||
setStr(td, "provider_identifier", "proxy")
|
||||
setStr(td, "tool_name", tool.Name)
|
||||
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||
}
|
||||
}
|
||||
|
||||
// RequestContextSuccess
|
||||
rcs := newMsg("RequestContextSuccess")
|
||||
setMsg(rcs, "request_context", rc)
|
||||
|
||||
// RequestContextResult (oneof success)
|
||||
rcr := newMsg("RequestContextResult")
|
||||
setMsg(rcr, "success", rcs)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
|
||||
}
|
||||
|
||||
// EncodeExecMcpResult responds with MCP tool result.
|
||||
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
|
||||
textContent := newMsg("McpTextContent")
|
||||
setStr(textContent, "text", content)
|
||||
|
||||
contentItem := newMsg("McpToolResultContentItem")
|
||||
setMsg(contentItem, "text", textContent)
|
||||
|
||||
success := newMsg("McpSuccess")
|
||||
contentField := field(success, "content")
|
||||
contentList := success.Mutable(contentField).List()
|
||||
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
|
||||
setBool(success, "is_error", isError)
|
||||
|
||||
result := newMsg("McpResult")
|
||||
setMsg(result, "success", success)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||
}
|
||||
|
||||
// EncodeExecMcpError responds with MCP error.
|
||||
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
mcpErr := newMsg("McpError")
|
||||
setStr(mcpErr, "error", errMsg)
|
||||
|
||||
result := newMsg("McpResult")
|
||||
setMsg(result, "error", mcpErr)
|
||||
|
||||
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||
}
|
||||
|
||||
// --- Rejection encoders (mirror handleExecMessage rejections) ---
|
||||
|
||||
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("ReadRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("ReadResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||
rej := newMsg("ShellRejected")
|
||||
setStr(rej, "command", command)
|
||||
setStr(rej, "working_directory", workDir)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("ShellResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("WriteRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("WriteResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("DeleteRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("DeleteResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||
rej := newMsg("LsRejected")
|
||||
setStr(rej, "path", path)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("LsResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
grepErr := newMsg("GrepError")
|
||||
setStr(grepErr, "error", errMsg)
|
||||
result := newMsg("GrepResult")
|
||||
setMsg(result, "error", grepErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
|
||||
fetchErr := newMsg("FetchError")
|
||||
setStr(fetchErr, "url", url)
|
||||
setStr(fetchErr, "error", errMsg)
|
||||
result := newMsg("FetchResult")
|
||||
setMsg(result, "error", fetchErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
|
||||
result := newMsg("DiagnosticsResult")
|
||||
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||
rej := newMsg("ShellRejected")
|
||||
setStr(rej, "command", command)
|
||||
setStr(rej, "working_directory", workDir)
|
||||
setStr(rej, "reason", reason)
|
||||
result := newMsg("BackgroundShellSpawnResult")
|
||||
setMsg(result, "rejected", rej)
|
||||
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
|
||||
}
|
||||
|
||||
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||
wsErr := newMsg("WriteShellStdinError")
|
||||
setStr(wsErr, "error", errMsg)
|
||||
result := newMsg("WriteShellStdinResult")
|
||||
setMsg(result, "error", wsErr)
|
||||
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
|
||||
}
|
||||
|
||||
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
|
||||
// Mirrors sendExec() in cursor-fetch.ts.
|
||||
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
|
||||
ecm := newMsg("ExecClientMessage")
|
||||
setUint32(ecm, "id", id)
|
||||
// Force set exec_id even if empty - Cursor requires this field to be set
|
||||
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
|
||||
|
||||
// Debug: check if field exists
|
||||
fd := field(ecm, resultFieldName)
|
||||
if fd == nil {
|
||||
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
|
||||
}
|
||||
|
||||
// Debug: log the actual field being set
|
||||
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
|
||||
|
||||
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
|
||||
|
||||
acm := newMsg("AgentClientMessage")
|
||||
setMsg(acm, "exec_client_message", ecm)
|
||||
return marshal(acm)
|
||||
}
|
||||
|
||||
func listFields(msg *dynamicpb.Message) []string {
|
||||
var names []string
|
||||
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
|
||||
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --- Utilities ---
|
||||
|
||||
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
|
||||
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
|
||||
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
|
||||
if len(jsonData) == 0 {
|
||||
return nil
|
||||
}
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(jsonData, &v); err != nil {
|
||||
return jsonData // fallback to raw JSON if parsing fails
|
||||
}
|
||||
pbVal, err := structpb.NewValue(v)
|
||||
if err != nil {
|
||||
return jsonData // fallback
|
||||
}
|
||||
b, err := proto.Marshal(pbVal)
|
||||
if err != nil {
|
||||
return jsonData // fallback
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
|
||||
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
|
||||
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
|
||||
val := &structpb.Value{}
|
||||
if err := proto.Unmarshal(data, val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return val.AsInterface(), nil
|
||||
}
|
||||
|
||||
func sha256Sum(data []byte) []byte {
|
||||
h := sha256.Sum256(data)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
var idCounter uint64
|
||||
|
||||
func generateId() string {
|
||||
idCounter++
|
||||
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
@@ -0,0 +1,332 @@
|
||||
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
|
||||
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
|
||||
package proto
|
||||
|
||||
// AgentClientMessage (msg 118) oneof "message"
|
||||
const (
|
||||
ACM_RunRequest = 1 // AgentRunRequest
|
||||
ACM_ExecClientMessage = 2 // ExecClientMessage
|
||||
ACM_KvClientMessage = 3 // KvClientMessage
|
||||
ACM_ConversationAction = 4 // ConversationAction
|
||||
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
|
||||
ACM_InteractionResponse = 6 // InteractionResponse
|
||||
ACM_ClientHeartbeat = 7 // ClientHeartbeat
|
||||
)
|
||||
|
||||
// AgentServerMessage (msg 119) oneof "message"
|
||||
const (
|
||||
ASM_InteractionUpdate = 1 // InteractionUpdate
|
||||
ASM_ExecServerMessage = 2 // ExecServerMessage
|
||||
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
|
||||
ASM_KvServerMessage = 4 // KvServerMessage
|
||||
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
|
||||
ASM_InteractionQuery = 7 // InteractionQuery
|
||||
)
|
||||
|
||||
// AgentRunRequest (msg 91)
|
||||
const (
|
||||
ARR_ConversationState = 1 // ConversationStateStructure
|
||||
ARR_Action = 2 // ConversationAction
|
||||
ARR_ModelDetails = 3 // ModelDetails
|
||||
ARR_McpTools = 4 // McpTools
|
||||
ARR_ConversationId = 5 // string (optional)
|
||||
)
|
||||
|
||||
// ConversationStateStructure (msg 83)
|
||||
const (
|
||||
CSS_RootPromptMessagesJson = 1 // repeated bytes
|
||||
CSS_TurnsOld = 2 // repeated bytes (deprecated)
|
||||
CSS_Todos = 3 // repeated bytes
|
||||
CSS_PendingToolCalls = 4 // repeated string
|
||||
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
|
||||
CSS_PreviousWorkspaceUris = 9 // repeated string
|
||||
CSS_SelfSummaryCount = 17 // uint32
|
||||
CSS_ReadPaths = 18 // repeated string
|
||||
)
|
||||
|
||||
// ConversationAction (msg 54) oneof "action"
|
||||
const (
|
||||
CA_UserMessageAction = 1 // UserMessageAction
|
||||
)
|
||||
|
||||
// UserMessageAction (msg 55)
|
||||
const (
|
||||
UMA_UserMessage = 1 // UserMessage
|
||||
)
|
||||
|
||||
// UserMessage (msg 63)
|
||||
const (
|
||||
UM_Text = 1 // string
|
||||
UM_MessageId = 2 // string
|
||||
UM_SelectedContext = 3 // SelectedContext (optional)
|
||||
)
|
||||
|
||||
// SelectedContext
|
||||
const (
|
||||
SC_SelectedImages = 1 // repeated SelectedImage
|
||||
)
|
||||
|
||||
// SelectedImage
|
||||
const (
|
||||
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
|
||||
SI_Uuid = 2 // string
|
||||
SI_Path = 3 // string
|
||||
SI_MimeType = 7 // string
|
||||
SI_Data = 8 // bytes (oneof dataOrBlobId)
|
||||
)
|
||||
|
||||
// ModelDetails (msg 88)
|
||||
const (
|
||||
MD_ModelId = 1 // string
|
||||
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
|
||||
MD_DisplayModelId = 3 // string
|
||||
MD_DisplayName = 4 // string
|
||||
)
|
||||
|
||||
// McpTools (msg 307)
|
||||
const (
|
||||
MT_McpTools = 1 // repeated McpToolDefinition
|
||||
)
|
||||
|
||||
// McpToolDefinition (msg 306)
|
||||
const (
|
||||
MTD_Name = 1 // string
|
||||
MTD_Description = 2 // string
|
||||
MTD_InputSchema = 3 // bytes
|
||||
MTD_ProviderIdentifier = 4 // string
|
||||
MTD_ToolName = 5 // string
|
||||
)
|
||||
|
||||
// ConversationTurnStructure (msg 70) oneof "turn"
|
||||
const (
|
||||
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
|
||||
)
|
||||
|
||||
// AgentConversationTurnStructure (msg 72)
|
||||
const (
|
||||
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
|
||||
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
|
||||
)
|
||||
|
||||
// ConversationStep (msg 53) oneof "message"
|
||||
const (
|
||||
CS_AssistantMessage = 1 // AssistantMessage
|
||||
)
|
||||
|
||||
// AssistantMessage
|
||||
const (
|
||||
AM_Text = 1 // string
|
||||
)
|
||||
|
||||
// --- Server-side message fields ---
|
||||
|
||||
// InteractionUpdate oneof "message"
|
||||
const (
|
||||
IU_TextDelta = 1 // TextDeltaUpdate
|
||||
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||
)
|
||||
|
||||
// TextDeltaUpdate (msg 92)
|
||||
const (
|
||||
TDU_Text = 1 // string
|
||||
)
|
||||
|
||||
// ThinkingDeltaUpdate (msg 97)
|
||||
const (
|
||||
TKD_Text = 1 // string
|
||||
)
|
||||
|
||||
// KvServerMessage (msg 271)
|
||||
const (
|
||||
KSM_Id = 1 // uint32
|
||||
KSM_GetBlobArgs = 2 // GetBlobArgs
|
||||
KSM_SetBlobArgs = 3 // SetBlobArgs
|
||||
)
|
||||
|
||||
// GetBlobArgs (msg 267)
|
||||
const (
|
||||
GBA_BlobId = 1 // bytes
|
||||
)
|
||||
|
||||
// SetBlobArgs (msg 269)
|
||||
const (
|
||||
SBA_BlobId = 1 // bytes
|
||||
SBA_BlobData = 2 // bytes
|
||||
)
|
||||
|
||||
// KvClientMessage (msg 272)
|
||||
const (
|
||||
KCM_Id = 1 // uint32
|
||||
KCM_GetBlobResult = 2 // GetBlobResult
|
||||
KCM_SetBlobResult = 3 // SetBlobResult
|
||||
)
|
||||
|
||||
// GetBlobResult (msg 268)
|
||||
const (
|
||||
GBR_BlobData = 1 // bytes (optional)
|
||||
)
|
||||
|
||||
// ExecServerMessage
|
||||
const (
|
||||
ESM_Id = 1 // uint32
|
||||
ESM_ExecId = 15 // string
|
||||
// oneof message:
|
||||
ESM_ShellArgs = 2 // ShellArgs
|
||||
ESM_WriteArgs = 3 // WriteArgs
|
||||
ESM_DeleteArgs = 4 // DeleteArgs
|
||||
ESM_GrepArgs = 5 // GrepArgs
|
||||
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||
ESM_LsArgs = 8 // LsArgs
|
||||
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||
ESM_McpArgs = 11 // McpArgs
|
||||
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||
ESM_FetchArgs = 20 // FetchArgs
|
||||
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||
)
|
||||
|
||||
// ExecClientMessage
|
||||
const (
|
||||
ECM_Id = 1 // uint32
|
||||
ECM_ExecId = 15 // string
|
||||
// oneof message (mirrors server fields):
|
||||
ECM_ShellResult = 2
|
||||
ECM_WriteResult = 3
|
||||
ECM_DeleteResult = 4
|
||||
ECM_GrepResult = 5
|
||||
ECM_ReadResult = 7
|
||||
ECM_LsResult = 8
|
||||
ECM_DiagnosticsResult = 9
|
||||
ECM_RequestContextResult = 10
|
||||
ECM_McpResult = 11
|
||||
ECM_ShellStream = 14
|
||||
ECM_BackgroundShellSpawnRes = 16
|
||||
ECM_FetchResult = 20
|
||||
ECM_WriteShellStdinResult = 23
|
||||
)
|
||||
|
||||
// McpArgs
|
||||
const (
|
||||
MCA_Name = 1 // string
|
||||
MCA_Args = 2 // map<string, bytes>
|
||||
MCA_ToolCallId = 3 // string
|
||||
MCA_ProviderIdentifier = 4 // string
|
||||
MCA_ToolName = 5 // string
|
||||
)
|
||||
|
||||
// RequestContextResult oneof "result"
|
||||
const (
|
||||
RCR_Success = 1 // RequestContextSuccess
|
||||
RCR_Error = 2 // RequestContextError
|
||||
)
|
||||
|
||||
// RequestContextSuccess (msg 337)
|
||||
const (
|
||||
RCS_RequestContext = 1 // RequestContext
|
||||
)
|
||||
|
||||
// RequestContext
|
||||
const (
|
||||
RC_Rules = 2 // repeated CursorRule
|
||||
RC_Tools = 7 // repeated McpToolDefinition
|
||||
)
|
||||
|
||||
// McpResult oneof "result"
|
||||
const (
|
||||
MCR_Success = 1 // McpSuccess
|
||||
MCR_Error = 2 // McpError
|
||||
MCR_Rejected = 3 // McpRejected
|
||||
)
|
||||
|
||||
// McpSuccess (msg 290)
|
||||
const (
|
||||
MCS_Content = 1 // repeated McpToolResultContentItem
|
||||
MCS_IsError = 2 // bool
|
||||
)
|
||||
|
||||
// McpToolResultContentItem oneof "content"
|
||||
const (
|
||||
MTRCI_Text = 1 // McpTextContent
|
||||
)
|
||||
|
||||
// McpTextContent (msg 287)
|
||||
const (
|
||||
MTC_Text = 1 // string
|
||||
)
|
||||
|
||||
// McpError (msg 291)
|
||||
const (
|
||||
MCE_Error = 1 // string
|
||||
)
|
||||
|
||||
// --- Rejection messages ---
|
||||
|
||||
// ReadRejected: path=1, reason=2
|
||||
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
|
||||
// WriteRejected: path=1, reason=2
|
||||
// DeleteRejected: path=1, reason=2
|
||||
// LsRejected: path=1, reason=2
|
||||
// GrepError: error=1
|
||||
// FetchError: url=1, error=2
|
||||
// WriteShellStdinError: error=1
|
||||
|
||||
// ReadResult oneof: success=1, error=2, rejected=3
|
||||
// ShellResult oneof: success=1 (+ various), rejected=?
|
||||
// The TS code uses specific result field numbers from the oneof:
|
||||
const (
|
||||
RR_Rejected = 3 // ReadResult.rejected
|
||||
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||
WR_Rejected = 5 // WriteResult.rejected
|
||||
DR_Rejected = 3 // DeleteResult.rejected
|
||||
LR_Rejected = 3 // LsResult.rejected
|
||||
GR_Error = 2 // GrepResult.error
|
||||
FR_Error = 2 // FetchResult.error
|
||||
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
|
||||
WSSR_Error = 2 // WriteShellStdinResult.error
|
||||
)
|
||||
|
||||
// --- Rejection struct fields ---
|
||||
const (
|
||||
REJ_Path = 1
|
||||
REJ_Reason = 2
|
||||
SREJ_Command = 1
|
||||
SREJ_WorkingDir = 2
|
||||
SREJ_Reason = 3
|
||||
SREJ_IsReadonly = 4
|
||||
GERR_Error = 1
|
||||
FERR_Url = 1
|
||||
FERR_Error = 2
|
||||
)
|
||||
|
||||
// ReadArgs
|
||||
const (
|
||||
RA_Path = 1 // string
|
||||
)
|
||||
|
||||
// WriteArgs
|
||||
const (
|
||||
WA_Path = 1 // string
|
||||
)
|
||||
|
||||
// DeleteArgs
|
||||
const (
|
||||
DA_Path = 1 // string
|
||||
)
|
||||
|
||||
// LsArgs
|
||||
const (
|
||||
LA_Path = 1 // string
|
||||
)
|
||||
|
||||
// ShellArgs
|
||||
const (
|
||||
SHA_Command = 1 // string
|
||||
SHA_WorkingDirectory = 2 // string
|
||||
)
|
||||
|
||||
// FetchArgs
|
||||
const (
|
||||
FA_Url = 1 // string
|
||||
)
|
||||
313
internal/auth/cursor/proto/h2stream.go
Normal file
313
internal/auth/cursor/proto/h2stream.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||
)
|
||||
|
||||
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||
type H2Stream struct {
|
||||
framer *http2.Framer
|
||||
conn net.Conn
|
||||
streamID uint32
|
||||
mu sync.Mutex
|
||||
id string // unique identifier for debugging
|
||||
frameNum int64 // sequential frame counter for debugging
|
||||
|
||||
dataCh chan []byte
|
||||
doneCh chan struct{}
|
||||
err error
|
||||
|
||||
// Send-side flow control
|
||||
sendWindow int32 // available bytes we can send on this stream
|
||||
connWindow int32 // available bytes on the connection level
|
||||
windowCond *sync.Cond // signaled when window is updated
|
||||
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||
}
|
||||
|
||||
// ID returns the unique identifier for this stream (for logging).
|
||||
func (s *H2Stream) ID() string { return s.id }
|
||||
|
||||
// FrameNum returns the current frame number for debugging.
|
||||
func (s *H2Stream) FrameNum() int64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.frameNum
|
||||
}
|
||||
|
||||
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
|
||||
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
|
||||
NextProtos: []string{"h2"},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
|
||||
}
|
||||
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: server did not negotiate h2")
|
||||
}
|
||||
|
||||
framer := http2.NewFramer(tlsConn, tlsConn)
|
||||
|
||||
// Client connection preface
|
||||
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||
}
|
||||
|
||||
// Send initial SETTINGS (tell server how much WE can receive)
|
||||
if err := framer.WriteSettings(
|
||||
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||
); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||
}
|
||||
|
||||
// Connection-level window update (for receiving)
|
||||
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||
}
|
||||
|
||||
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||
// Track server's initial window size (how much WE can send)
|
||||
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||
for i := 0; i < 10; i++ {
|
||||
f, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
|
||||
}
|
||||
switch sf := f.(type) {
|
||||
case *http2.SettingsFrame:
|
||||
if !sf.IsAck() {
|
||||
sf.ForeachSetting(func(s http2.Setting) error {
|
||||
if s.ID == http2.SettingInitialWindowSize {
|
||||
serverInitialWindowSize = int32(s.Val)
|
||||
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
framer.WriteSettingsAck()
|
||||
} else {
|
||||
goto handshakeDone
|
||||
}
|
||||
case *http2.WindowUpdateFrame:
|
||||
if sf.StreamID == 0 {
|
||||
connWindowSize += int32(sf.Increment)
|
||||
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||
}
|
||||
default:
|
||||
// unexpected but continue
|
||||
}
|
||||
}
|
||||
handshakeDone:
|
||||
|
||||
// Build HEADERS
|
||||
streamID := uint32(1)
|
||||
var hdrBuf []byte
|
||||
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
|
||||
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
|
||||
if p, ok := headers[":path"]; ok {
|
||||
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
|
||||
}
|
||||
for k, v := range headers {
|
||||
if len(k) > 0 && k[0] == ':' {
|
||||
continue
|
||||
}
|
||||
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
|
||||
if err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
BlockFragment: hdrBuf,
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
}); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("h2: headers write failed: %w", err)
|
||||
}
|
||||
|
||||
s := &H2Stream{
|
||||
framer: framer,
|
||||
conn: tlsConn,
|
||||
streamID: streamID,
|
||||
dataCh: make(chan []byte, 256),
|
||||
doneCh: make(chan struct{}),
|
||||
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||
frameNum: 0,
|
||||
sendWindow: serverInitialWindowSize,
|
||||
connWindow: connWindowSize,
|
||||
}
|
||||
s.windowCond = sync.NewCond(&s.windowMu)
|
||||
go s.readLoop()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Write sends a DATA frame on the stream, respecting flow control.
|
||||
func (s *H2Stream) Write(data []byte) error {
|
||||
for len(data) > 0 {
|
||||
chunk := data
|
||||
if len(chunk) > maxFramePayload {
|
||||
chunk = data[:maxFramePayload]
|
||||
}
|
||||
|
||||
// Wait for flow control window
|
||||
s.windowMu.Lock()
|
||||
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||
s.windowCond.Wait()
|
||||
}
|
||||
// Limit chunk to available window
|
||||
allowed := int(s.sendWindow)
|
||||
if int(s.connWindow) < allowed {
|
||||
allowed = int(s.connWindow)
|
||||
}
|
||||
if len(chunk) > allowed {
|
||||
chunk = chunk[:allowed]
|
||||
}
|
||||
s.sendWindow -= int32(len(chunk))
|
||||
s.connWindow -= int32(len(chunk))
|
||||
s.windowMu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[len(chunk):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns the channel of received data chunks.
|
||||
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
|
||||
|
||||
// Done returns a channel closed when the stream ends.
|
||||
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
||||
|
||||
// Err returns the error (if any) that caused the stream to close.
|
||||
// Returns nil for a clean shutdown (EOF / StreamEnded).
|
||||
func (s *H2Stream) Err() error { return s.err }
|
||||
|
||||
// Close tears down the connection.
|
||||
func (s *H2Stream) Close() {
|
||||
s.conn.Close()
|
||||
// Unblock any writers waiting on flow control
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
|
||||
func (s *H2Stream) readLoop() {
|
||||
defer close(s.doneCh)
|
||||
defer close(s.dataCh)
|
||||
|
||||
for {
|
||||
f, err := s.framer.ReadFrame()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
s.err = err
|
||||
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Increment frame counter
|
||||
s.mu.Lock()
|
||||
s.frameNum++
|
||||
s.mu.Unlock()
|
||||
|
||||
switch frame := f.(type) {
|
||||
case *http2.DataFrame:
|
||||
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||
cp := make([]byte, len(frame.Data()))
|
||||
copy(cp, frame.Data())
|
||||
s.dataCh <- cp
|
||||
|
||||
// Flow control: send WINDOW_UPDATE for received data
|
||||
s.mu.Lock()
|
||||
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||
s.mu.Unlock()
|
||||
}
|
||||
if frame.StreamEnded() {
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.HeadersFrame:
|
||||
if frame.StreamEnded() {
|
||||
return
|
||||
}
|
||||
|
||||
case *http2.RSTStreamFrame:
|
||||
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.GoAwayFrame:
|
||||
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||
return
|
||||
|
||||
case *http2.PingFrame:
|
||||
if !frame.IsAck() {
|
||||
s.mu.Lock()
|
||||
s.framer.WritePing(true, frame.Data)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.SettingsFrame:
|
||||
if !frame.IsAck() {
|
||||
// Check for window size changes
|
||||
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||
if setting.ID == http2.SettingInitialWindowSize {
|
||||
s.windowMu.Lock()
|
||||
delta := int32(setting.Val) - s.sendWindow
|
||||
s.sendWindow += delta
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.mu.Lock()
|
||||
s.framer.WriteSettingsAck()
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
case *http2.WindowUpdateFrame:
|
||||
// Update send-side flow control window
|
||||
s.windowMu.Lock()
|
||||
if frame.StreamID == 0 {
|
||||
s.connWindow += int32(frame.Increment)
|
||||
} else if frame.StreamID == s.streamID {
|
||||
s.sendWindow += int32(frame.Increment)
|
||||
}
|
||||
s.windowMu.Unlock()
|
||||
s.windowCond.Broadcast()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sliceWriter struct{ buf *[]byte }
|
||||
|
||||
func (w *sliceWriter) Write(p []byte) (int, error) {
|
||||
*w.buf = append(*w.buf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -305,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
var manualInputCh <-chan string
|
||||
var manualInputErrCh <-chan error
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
@@ -326,13 +329,14 @@ waitForCallback:
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := misc.ParseOAuthCallback(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
continue
|
||||
case input := <-manualInputCh:
|
||||
manualInputCh = nil
|
||||
manualInputErrCh = nil
|
||||
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
@@ -345,6 +349,8 @@ waitForCallback:
|
||||
}
|
||||
authCode = parsed.Code
|
||||
break waitForCallback
|
||||
case errManual := <-manualInputErrCh:
|
||||
return nil, errManual
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
)
|
||||
|
||||
// newAuthManager creates a new authentication manager instance with all supported
|
||||
// authenticators and a file-based token store. It initializes authenticators for
|
||||
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
|
||||
// authenticators and a file-based token store.
|
||||
//
|
||||
// Returns:
|
||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||
@@ -24,6 +23,8 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||
sdkAuth.NewCursorAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
43
internal/cmd/codebuddy_login.go
Normal file
43
internal/cmd/codebuddy_login.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
|
||||
// It initiates the OAuth authentication, displays the user code for the user to enter
|
||||
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy and auth directory settings
|
||||
// - options: Login options including browser behavior settings
|
||||
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("CodeBuddy authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("CodeBuddy authentication successful!")
|
||||
}
|
||||
37
internal/cmd/cursor_login.go
Normal file
37
internal/cmd/cursor_login.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
|
||||
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Cursor authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
log.Infof("Authentication saved to %s", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
log.Infof("Authenticated as %s", record.Label)
|
||||
}
|
||||
log.Info("Cursor authentication successful!")
|
||||
}
|
||||
55
internal/config/claude_header_defaults_test.go
Normal file
55
internal/config/claude_header_defaults_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
configYAML := []byte(`
|
||||
claude-header-defaults:
|
||||
user-agent: " claude-cli/2.1.70 (external, cli) "
|
||||
package-version: " 0.80.0 "
|
||||
runtime-version: " v24.5.0 "
|
||||
os: " MacOS "
|
||||
arch: " arm64 "
|
||||
timeout: " 900 "
|
||||
stabilize-device-profile: false
|
||||
`)
|
||||
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfigOptional(configPath, false)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||
}
|
||||
|
||||
if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
|
||||
t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
|
||||
t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
|
||||
t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
|
||||
t.Fatalf("OS = %q, want %q", got, "MacOS")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
|
||||
t.Fatalf("Arch = %q, want %q", got, "arm64")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
|
||||
t.Fatalf("Timeout = %q, want %q", got, "900")
|
||||
}
|
||||
if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||
t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
|
||||
}
|
||||
if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
|
||||
t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -145,13 +146,19 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
|
||||
// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
|
||||
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
|
||||
// profiles are enabled, OS/Arch become the pinned platform baseline, while
|
||||
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
OS string `yaml:"os" json:"os"`
|
||||
Arch string `yaml:"arch" json:"arch"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
|
||||
}
|
||||
|
||||
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||
@@ -188,6 +195,9 @@ type RemoteManagement struct {
|
||||
SecretKey string `yaml:"secret-key"`
|
||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||
// DisableAutoUpdatePanel disables automatic periodic background updates of the management panel asset from GitHub.
|
||||
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
|
||||
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
|
||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||
@@ -568,6 +578,10 @@ type OpenAICompatibilityModel struct {
|
||||
|
||||
// Alias is the model name alias that clients will use to reference this model.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
|
||||
// Thinking configures the thinking/reasoning capability for this model.
|
||||
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
|
||||
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||
@@ -694,6 +708,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize Codex header defaults.
|
||||
cfg.SanitizeCodexHeaderDefaults()
|
||||
|
||||
// Sanitize Claude header defaults.
|
||||
cfg.SanitizeClaudeHeaderDefaults()
|
||||
|
||||
// Sanitize Claude key headers
|
||||
cfg.SanitizeClaudeKeys()
|
||||
|
||||
@@ -796,6 +813,20 @@ func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||
}
|
||||
|
||||
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
|
||||
// configured Claude fingerprint baseline values.
|
||||
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
|
||||
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
|
||||
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
|
||||
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
|
||||
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
|
||||
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
managementSyncMinInterval = 30 * time.Second
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
maxAssetDownloadSize = 50 << 20 // 10 MB safety limit for management asset downloads
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
|
||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||
return
|
||||
}
|
||||
if cfg.RemoteManagement.DisableAutoUpdatePanel {
|
||||
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
|
||||
return
|
||||
}
|
||||
|
||||
configPath, _ := schedulerConfigPath.Load().(string)
|
||||
staticDir := StaticDir(configPath)
|
||||
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
|
||||
return false
|
||||
}
|
||||
|
||||
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
|
||||
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||
return false
|
||||
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
|
||||
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("read download body: %w", err)
|
||||
}
|
||||
if int64(len(data)) > maxAssetDownloadSize {
|
||||
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(data)
|
||||
return data, hex.EncodeToString(sum[:]), nil
|
||||
|
||||
@@ -30,6 +30,23 @@ type OAuthCallback struct {
|
||||
ErrorDescription string
|
||||
}
|
||||
|
||||
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
|
||||
// the result. The returned channels are buffered (size 1) so the goroutine can
|
||||
// complete even if the caller abandons the channels.
|
||||
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
|
||||
inputCh := make(chan string, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
input, err := promptFn(message)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
inputCh <- input
|
||||
}()
|
||||
return inputCh, errCh
|
||||
}
|
||||
|
||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||
// It returns nil when the input is empty.
|
||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||
|
||||
@@ -88,6 +88,87 @@ func GetAntigravityModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Antigravity)
|
||||
}
|
||||
|
||||
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
|
||||
// These models are served through the copilot.tencent.com API.
|
||||
func GetCodeBuddyModels() []*ModelInfo {
|
||||
now := int64(1748044800) // 2025-05-24
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "glm-5.0",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "GLM-5.0",
|
||||
Description: "GLM-5.0 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "glm-4.7",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "GLM-4.7",
|
||||
Description: "GLM-4.7 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "minimax-m2.5",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "MiniMax M2.5",
|
||||
Description: "MiniMax M2.5 via CodeBuddy",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2.5",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Kimi K2.5",
|
||||
Description: "Kimi K2.5 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "deepseek-v3-2-volc",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "hunyuan-2.0-thinking",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Hunyuan 2.0 Thinking",
|
||||
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
@@ -148,11 +229,27 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
return GetAntigravityModels()
|
||||
case "codebuddy":
|
||||
return GetCodeBuddyModels()
|
||||
case "cursor":
|
||||
return GetCursorModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCursorModels returns the fallback Cursor model definitions.
|
||||
func GetCursorModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
|
||||
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
|
||||
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
|
||||
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
@@ -176,6 +273,8 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
GetCodeBuddyModels(),
|
||||
GetCursorModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
@@ -365,6 +464,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.4",
|
||||
Description: "OpenAI GPT-5.4 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
Object: "model",
|
||||
|
||||
@@ -73,16 +73,16 @@ type availableModelsCacheEntry struct {
|
||||
// Values are interpreted in provider-native token units.
|
||||
type ThinkingSupport struct {
|
||||
// Min is the minimum allowed thinking budget (inclusive).
|
||||
Min int `json:"min,omitempty"`
|
||||
Min int `json:"min,omitempty" yaml:"min,omitempty"`
|
||||
// Max is the maximum allowed thinking budget (inclusive).
|
||||
Max int `json:"max,omitempty"`
|
||||
Max int `json:"max,omitempty" yaml:"max,omitempty"`
|
||||
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
|
||||
ZeroAllowed bool `json:"zero_allowed,omitempty"`
|
||||
ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
|
||||
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
|
||||
DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
|
||||
DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
|
||||
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
|
||||
// When set, the model uses level-based reasoning instead of token budgets.
|
||||
Levels []string `json:"levels,omitempty"`
|
||||
Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
|
||||
}
|
||||
|
||||
// ModelRegistration tracks a model's availability
|
||||
|
||||
@@ -164,7 +164,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -296,7 +296,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
||||
return false
|
||||
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||
}
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for AI Studio).
|
||||
|
||||
@@ -308,7 +308,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -512,7 +512,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -691,31 +691,42 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
partsJSON, _ := json.Marshal(parts)
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
||||
updatedTemplate, _ := sjson.SetRawBytes([]byte(responseTemplate), "candidates.0.content.parts", partsJSON)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
if role != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.content.role", role)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if finishReason != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.finishReason", finishReason)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if modelVersion != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "modelVersion", modelVersion)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if responseID != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "responseId", responseID)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if usageRaw != "" {
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
||||
updatedTemplate, _ = sjson.SetRawBytes([]byte(responseTemplate), "usageMetadata", []byte(usageRaw))
|
||||
responseTemplate = string(updatedTemplate)
|
||||
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.totalTokenCount", 0)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
|
||||
output := `{"response":{},"traceId":""}`
|
||||
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
||||
updatedOutput, _ := sjson.SetRawBytes([]byte(output), "response", []byte(responseTemplate))
|
||||
output = string(updatedOutput)
|
||||
if traceID != "" {
|
||||
output, _ = sjson.Set(output, "traceId", traceID)
|
||||
updatedOutput, _ = sjson.SetBytes([]byte(output), "traceId", traceID)
|
||||
output = string(updatedOutput)
|
||||
}
|
||||
return []byte(output)
|
||||
}
|
||||
@@ -880,12 +891,12 @@ attemptLoop:
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -1043,7 +1054,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
@@ -1265,19 +1276,20 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
|
||||
// if useAntigravitySchema {
|
||||
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.role", "user")
|
||||
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
// for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
// payloadStr, _ = sjson.SetRawBytes([]byte(payloadStr), "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payloadStr = string(updated)
|
||||
} else {
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
@@ -1499,8 +1511,9 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
}
|
||||
|
||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
template := payload
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "userAgent", "antigravity")
|
||||
|
||||
isImageModel := strings.Contains(modelName, "image")
|
||||
|
||||
@@ -1510,28 +1523,28 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
} else {
|
||||
reqType = "agent"
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestType", reqType)
|
||||
template, _ = sjson.SetBytes(template, "requestType", reqType)
|
||||
|
||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||
if projectID != "" {
|
||||
template, _ = sjson.Set(template, "project", projectID)
|
||||
template, _ = sjson.SetBytes(template, "project", projectID)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
template, _ = sjson.SetBytes(template, "project", generateProjectID())
|
||||
}
|
||||
|
||||
if isImageModel {
|
||||
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
|
||||
template, _ = sjson.SetBytes(template, "requestId", generateImageGenRequestID())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
template, _ = sjson.SetBytes(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.SetBytes(template, "request.sessionId", generateStableSessionID(payload))
|
||||
}
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
|
||||
template, _ = sjson.Delete(template, "toolConfig")
|
||||
template, _ = sjson.DeleteBytes(template, "request.safetySettings")
|
||||
if toolConfig := gjson.GetBytes(template, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(template, "request.toolConfig").Exists() {
|
||||
template, _ = sjson.SetRawBytes(template, "request.toolConfig", []byte(toolConfig.Raw))
|
||||
template, _ = sjson.DeleteBytes(template, "toolConfig")
|
||||
}
|
||||
return []byte(template)
|
||||
return template
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
|
||||
383
internal/runtime/executor/claude_device_profile.go
Normal file
383
internal/runtime/executor/claude_device_profile.go
Normal file
@@ -0,0 +1,383 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClaudeFingerprintUserAgent = "claude-cli/2.1.63 (external, cli)"
|
||||
defaultClaudeFingerprintPackageVersion = "0.74.0"
|
||||
defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
|
||||
defaultClaudeFingerprintOS = "MacOS"
|
||||
defaultClaudeFingerprintArch = "arm64"
|
||||
claudeDeviceProfileTTL = 7 * 24 * time.Hour
|
||||
claudeDeviceProfileCleanupPeriod = time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)
|
||||
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu sync.RWMutex
|
||||
claudeDeviceProfileCacheCleanupOnce sync.Once
|
||||
|
||||
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
|
||||
)
|
||||
|
||||
type claudeCLIVersion struct {
|
||||
major int
|
||||
minor int
|
||||
patch int
|
||||
}
|
||||
|
||||
func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
|
||||
switch {
|
||||
case v.major != other.major:
|
||||
if v.major > other.major {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
case v.minor != other.minor:
|
||||
if v.minor > other.minor {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
case v.patch != other.patch:
|
||||
if v.patch > other.patch {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
type claudeDeviceProfile struct {
|
||||
UserAgent string
|
||||
PackageVersion string
|
||||
RuntimeVersion string
|
||||
OS string
|
||||
Arch string
|
||||
Version claudeCLIVersion
|
||||
HasVersion bool
|
||||
}
|
||||
|
||||
type claudeDeviceProfileCacheEntry struct {
|
||||
profile claudeDeviceProfile
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||
return false
|
||||
}
|
||||
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
||||
}
|
||||
|
||||
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if strings.TrimSpace(cfgVal) != "" {
|
||||
return strings.TrimSpace(cfgVal)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
profile := claudeDeviceProfile{
|
||||
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
||||
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
||||
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
||||
OS: hdrDefault(hd.OS, defaultClaudeFingerprintOS),
|
||||
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
||||
}
|
||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||
profile.Version = version
|
||||
profile.HasVersion = true
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
|
||||
matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
|
||||
if len(matches) != 4 {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
major, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
minor, err := strconv.Atoi(matches[2])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
patch, err := strconv.Atoi(matches[3])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
||||
}
|
||||
|
||||
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
|
||||
if candidate.UserAgent == "" || !candidate.HasVersion {
|
||||
return false
|
||||
}
|
||||
if current.UserAgent == "" || !current.HasVersion {
|
||||
return true
|
||||
}
|
||||
return candidate.Version.Compare(current.Version) > 0
|
||||
}
|
||||
|
||||
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||
profile.OS = baseline.OS
|
||||
profile.Arch = baseline.Arch
|
||||
return profile
|
||||
}
|
||||
|
||||
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
||||
// baseline platform and enforces the baseline software fingerprint as a floor.
|
||||
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
||||
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||
profile.UserAgent = baseline.UserAgent
|
||||
profile.PackageVersion = baseline.PackageVersion
|
||||
profile.RuntimeVersion = baseline.RuntimeVersion
|
||||
profile.Version = baseline.Version
|
||||
profile.HasVersion = baseline.HasVersion
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
|
||||
if headers == nil {
|
||||
return claudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
||||
version, ok := parseClaudeCLIVersion(userAgent)
|
||||
if !ok {
|
||||
return claudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
profile := claudeDeviceProfile{
|
||||
UserAgent: userAgent,
|
||||
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
||||
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
||||
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
||||
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
||||
Version: version,
|
||||
HasVersion: true,
|
||||
}
|
||||
return profile, true
|
||||
}
|
||||
|
||||
func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
|
||||
if headers == nil {
|
||||
return fallback
|
||||
}
|
||||
if value := strings.TrimSpace(headers.Get(name)); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||
switch {
|
||||
case auth != nil && strings.TrimSpace(auth.ID) != "":
|
||||
return "auth:" + strings.TrimSpace(auth.ID)
|
||||
case strings.TrimSpace(apiKey) != "":
|
||||
return "api_key:" + strings.TrimSpace(apiKey)
|
||||
default:
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func startClaudeDeviceProfileCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredClaudeDeviceProfiles()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredClaudeDeviceProfiles() {
|
||||
now := time.Now()
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
for key, entry := range claudeDeviceProfileCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(claudeDeviceProfileCache, key)
|
||||
}
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
|
||||
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
||||
|
||||
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
||||
now := time.Now()
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
|
||||
if hasCandidate {
|
||||
candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
|
||||
}
|
||||
if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
|
||||
hasCandidate = false
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.RLock()
|
||||
entry, hasCached := claudeDeviceProfileCache[cacheKey]
|
||||
cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||
claudeDeviceProfileCacheMu.RUnlock()
|
||||
|
||||
if hasCandidate {
|
||||
if claudeDeviceProfileBeforeCandidateStore != nil {
|
||||
claudeDeviceProfileBeforeCandidateStore(candidate)
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
entry, hasCached = claudeDeviceProfileCache[cacheKey]
|
||||
cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||
if cachedValid {
|
||||
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||
}
|
||||
if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
|
||||
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||
claudeDeviceProfileCache[cacheKey] = entry
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return entry.profile
|
||||
}
|
||||
|
||||
claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
|
||||
profile: candidate,
|
||||
expire: now.Add(claudeDeviceProfileTTL),
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return candidate
|
||||
}
|
||||
|
||||
if cachedValid {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
entry = claudeDeviceProfileCache[cacheKey]
|
||||
if entry.expire.After(now) && entry.profile.UserAgent != "" {
|
||||
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||
claudeDeviceProfileCache[cacheKey] = entry
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return entry.profile
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
return baseline
|
||||
}
|
||||
|
||||
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
for _, headerName := range []string{
|
||||
"User-Agent",
|
||||
"X-Stainless-Package-Version",
|
||||
"X-Stainless-Runtime-Version",
|
||||
"X-Stainless-Os",
|
||||
"X-Stainless-Arch",
|
||||
} {
|
||||
r.Header.Del(headerName)
|
||||
}
|
||||
r.Header.Set("User-Agent", profile.UserAgent)
|
||||
r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||
r.Header.Set("X-Stainless-Os", profile.OS)
|
||||
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||
}
|
||||
|
||||
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
profile := defaultClaudeDeviceProfile(cfg)
|
||||
miscEnsure := func(name, fallback string) {
|
||||
if strings.TrimSpace(r.Header.Get(name)) != "" {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ginHeaders.Get(name)) != "" {
|
||||
r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
|
||||
return
|
||||
}
|
||||
r.Header.Set(name, fallback)
|
||||
}
|
||||
|
||||
miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||
miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
|
||||
miscEnsure("X-Stainless-Os", mapStainlessOS())
|
||||
miscEnsure("X-Stainless-Arch", mapStainlessArch())
|
||||
|
||||
// Legacy mode preserves per-auth custom header overrides. By the time we get
|
||||
// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
|
||||
if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
return
|
||||
}
|
||||
r.Header.Set("User-Agent", profile.UserAgent)
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -255,7 +254,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -443,7 +442,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
¶m,
|
||||
)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -561,7 +560,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
@@ -767,36 +766,6 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if cfgVal != "" {
|
||||
@@ -824,6 +793,11 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
|
||||
var deviceProfile claudeDeviceProfile
|
||||
if stabilizeDeviceProfile {
|
||||
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
|
||||
}
|
||||
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
@@ -867,25 +841,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
// For User-Agent, only forward the client's header if it's already a Claude Code client.
|
||||
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
|
||||
// to avoid leaking the real client identity during cloaking.
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = ginHeaders.Get("User-Agent")
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
} else {
|
||||
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
@@ -897,13 +855,19 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
r.Header.Set("Accept", "application/json")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
|
||||
// to the configured baseline while still allowing newer official
|
||||
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
if stabilizeDeviceProfile {
|
||||
applyClaudeDeviceProfileHeaders(r, deviceProfile)
|
||||
} else {
|
||||
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
|
||||
}
|
||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||
@@ -1260,7 +1224,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
||||
partJSON := part.Raw
|
||||
if !part.Get("cache_control").Exists() {
|
||||
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
|
||||
updated, _ := sjson.SetBytes([]byte(partJSON), "cache_control.type", "ephemeral")
|
||||
partJSON = string(updated)
|
||||
}
|
||||
result += "," + partJSON
|
||||
}
|
||||
@@ -1268,7 +1233,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
})
|
||||
} else if system.Type == gjson.String && system.String() != "" {
|
||||
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", system.String())
|
||||
updated, _ := sjson.SetBytes([]byte(partJSON), "text", system.String())
|
||||
partJSON = string(updated)
|
||||
result += "," + partJSON
|
||||
}
|
||||
result += "]"
|
||||
|
||||
@@ -8,8 +8,11 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -19,6 +22,587 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func resetClaudeDeviceProfileCache() {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginReq := httptest.NewRequest(http.MethodPost, "http://localhost/v1/messages", nil)
|
||||
ginReq.Header = incoming.Clone()
|
||||
ginCtx.Request = ginReq
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||
return req.WithContext(context.WithValue(req.Context(), "gin", ginCtx))
|
||||
}
|
||||
|
||||
func assertClaudeFingerprint(t *testing.T, headers http.Header, userAgent, pkgVersion, runtimeVersion, osName, arch string) {
|
||||
t.Helper()
|
||||
|
||||
if got := headers.Get("User-Agent"); got != userAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", got, userAgent)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Package-Version"); got != pkgVersion {
|
||||
t.Fatalf("X-Stainless-Package-Version = %q, want %q", got, pkgVersion)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Runtime-Version"); got != runtimeVersion {
|
||||
t.Fatalf("X-Stainless-Runtime-Version = %q, want %q", got, runtimeVersion)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Os"); got != osName {
|
||||
t.Fatalf("X-Stainless-Os = %q, want %q", got, osName)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Arch"); got != arch {
|
||||
t.Fatalf("X-Stainless-Arch = %q, want %q", got, arch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
Timeout: "900",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseline",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-baseline",
|
||||
"header:User-Agent": "evil-client/9.9",
|
||||
"header:X-Stainless-Os": "Linux",
|
||||
"header:X-Stainless-Arch": "x64",
|
||||
"header:X-Stainless-Package-Version": "9.9.9",
|
||||
},
|
||||
}
|
||||
incoming := http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, incoming)
|
||||
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
||||
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_TracksHighestClaudeCLIFingerprint(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-upgrade",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
firstReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(firstReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"lobe-chat/1.0"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
|
||||
|
||||
higherReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||
"X-Stainless-Os": []string{"MacOS"},
|
||||
"X-Stainless-Arch": []string{"arm64"},
|
||||
})
|
||||
applyClaudeHeaders(higherReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, higherReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
|
||||
|
||||
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.73.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(lowerReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_DoesNotDowngradeConfiguredBaselineOnFirstClaudeClient(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseline-floor",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-baseline-floor",
|
||||
},
|
||||
}
|
||||
|
||||
olderClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(olderClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, olderClaudeReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
|
||||
newerClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.81.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(newerClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, newerClaudeReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UpgradesCachedSoftwareFingerprintWhenBaselineAdvances(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
oldCfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.77 (external, cli)",
|
||||
PackageVersion: "0.87.0",
|
||||
RuntimeVersion: "v24.8.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseline-reload",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-baseline-reload",
|
||||
},
|
||||
}
|
||||
|
||||
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.81.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(officialReq, auth, "key-baseline-reload", false, nil, oldCfg)
|
||||
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-baseline-reload", false, nil, newCfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_LearnsOfficialFingerprintAfterCustomBaselineFallback(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "my-gateway/1.0",
|
||||
PackageVersion: "custom-pkg",
|
||||
RuntimeVersion: "custom-runtime",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-custom-baseline-learning",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-custom-baseline-learning",
|
||||
},
|
||||
}
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "my-gateway/1.0", "custom-pkg", "custom-runtime", "MacOS", "arm64")
|
||||
|
||||
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.87.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(officialReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
|
||||
postLearningThirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(postLearningThirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, postLearningThirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-racy-upgrade",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-racy-upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
lowPaused := make(chan struct{})
|
||||
releaseLow := make(chan struct{})
|
||||
var pauseOnce sync.Once
|
||||
var releaseOnce sync.Once
|
||||
|
||||
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
|
||||
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
|
||||
return
|
||||
}
|
||||
pauseOnce.Do(func() { close(lowPaused) })
|
||||
<-releaseLow
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
claudeDeviceProfileBeforeCandidateStore = nil
|
||||
releaseOnce.Do(func() { close(releaseLow) })
|
||||
})
|
||||
|
||||
lowResultCh := make(chan claudeDeviceProfile, 1)
|
||||
go func() {
|
||||
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
}, cfg)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lowPaused:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for lower candidate to pause before storing")
|
||||
}
|
||||
|
||||
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||
"X-Stainless-Os": []string{"MacOS"},
|
||||
"X-Stainless-Arch": []string{"arm64"},
|
||||
}, cfg)
|
||||
releaseOnce.Do(func() { close(releaseLow) })
|
||||
|
||||
select {
|
||||
case lowResult := <-lowResultCh:
|
||||
if lowResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
t.Fatalf("lowResult.UserAgent = %q, want %q", lowResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||
}
|
||||
if lowResult.PackageVersion != "0.75.0" {
|
||||
t.Fatalf("lowResult.PackageVersion = %q, want %q", lowResult.PackageVersion, "0.75.0")
|
||||
}
|
||||
if lowResult.OS != "MacOS" || lowResult.Arch != "arm64" {
|
||||
t.Fatalf("lowResult platform = %s/%s, want %s/%s", lowResult.OS, lowResult.Arch, "MacOS", "arm64")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for lower candidate result")
|
||||
}
|
||||
|
||||
if highResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
t.Fatalf("highResult.UserAgent = %q, want %q", highResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||
}
|
||||
if highResult.OS != "MacOS" || highResult.Arch != "arm64" {
|
||||
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
|
||||
}
|
||||
|
||||
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
}, cfg)
|
||||
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
t.Fatalf("cached.UserAgent = %q, want %q", cached.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||
}
|
||||
if cached.PackageVersion != "0.75.0" {
|
||||
t.Fatalf("cached.PackageVersion = %q, want %q", cached.PackageVersion, "0.75.0")
|
||||
}
|
||||
if cached.OS != "MacOS" || cached.Arch != "arm64" {
|
||||
t.Fatalf("cached platform = %s/%s, want %s/%s", cached.OS, cached.Arch, "MacOS", "arm64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_ThirdPartyBaselineThenOfficialUpgradeKeepsPinnedPlatform(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-third-party-then-official",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-third-party-then-official",
|
||||
},
|
||||
}
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-third-party-then-official", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
|
||||
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.87.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(officialReq, auth, "key-third-party-then-official", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_DisableDeviceProfileStabilization(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
stabilize := false
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-disable-stability",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-disable-stability",
|
||||
},
|
||||
}
|
||||
|
||||
firstReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(firstReq, auth, "key-disable-stability", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "Linux", "x64")
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"lobe-chat/1.0"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-disable-stability", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.60 (external, cli)", "0.10.0", "v18.0.0", "Windows", "x64")
|
||||
|
||||
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.73.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(lowerReq, auth, "key-disable-stability", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.61 (external, cli)", "0.73.0", "v24.2.0", "Windows", "x64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_LegacyModePreservesConfiguredUserAgentOverrideForClaudeClients(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
stabilize := false
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-legacy-ua-override",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-legacy-ua-override",
|
||||
"header:User-Agent": "config-ua/1.0",
|
||||
},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-legacy-ua-override", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "config-ua/1.0", "0.74.0", "v24.3.0", "Linux", "x64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
stabilize := false
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-legacy-runtime-os-arch",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-legacy-runtime-os-arch",
|
||||
},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-unset-runtime-os-arch",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-unset-runtime-os-arch",
|
||||
},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||
}
|
||||
|
||||
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
|
||||
if claudeDeviceProfileStabilizationEnabled(nil) {
|
||||
t.Fatal("expected nil config to default to disabled stabilization")
|
||||
}
|
||||
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
||||
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
343
internal/runtime/executor/codebuddy_executor.go
Normal file
343
internal/runtime/executor/codebuddy_executor.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codeBuddyChatPath = "/v2/chat/completions"
|
||||
codeBuddyAuthType = "codebuddy"
|
||||
)
|
||||
|
||||
// CodeBuddyExecutor handles requests to the CodeBuddy API.
|
||||
type CodeBuddyExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
|
||||
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
|
||||
return &CodeBuddyExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
|
||||
|
||||
// codeBuddyCredentials extracts the access token and domain from auth metadata.
|
||||
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
|
||||
if auth == nil {
|
||||
return "", "", ""
|
||||
}
|
||||
accessToken = metaStringValue(auth.Metadata, "access_token")
|
||||
userID = metaStringValue(auth.Metadata, "user_id")
|
||||
domain = metaStringValue(auth.Metadata, "domain")
|
||||
if domain == "" {
|
||||
domain = codebuddy.DefaultDomain
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
e.applyHeaders(req, accessToken, userID, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("codebuddy executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := codebuddy.BaseURL + codeBuddyChatPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := codebuddy.BaseURL + codeBuddyChatPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy executor: close stream body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, maxScannerBufferSize)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh exchanges the CodeBuddy refresh token for a new access token.
|
||||
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("codebuddy: missing auth")
|
||||
}
|
||||
|
||||
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
|
||||
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
|
||||
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
updated := auth.Clone()
|
||||
updated.Metadata["access_token"] = storage.AccessToken
|
||||
if storage.RefreshToken != "" {
|
||||
updated.Metadata["refresh_token"] = storage.RefreshToken
|
||||
}
|
||||
updated.Metadata["expires_in"] = storage.ExpiresIn
|
||||
updated.Metadata["domain"] = storage.Domain
|
||||
if storage.UserID != "" {
|
||||
updated.Metadata["user_id"] = storage.UserID
|
||||
}
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens is not supported for CodeBuddy.
|
||||
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
|
||||
}
|
||||
|
||||
// applyHeaders sets required headers for CodeBuddy API requests.
|
||||
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", codebuddy.UserAgent)
|
||||
req.Header.Set("X-User-Id", userID)
|
||||
req.Header.Set("X-Domain", domain)
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("X-IDE-Type", "CLI")
|
||||
req.Header.Set("X-IDE-Name", "CLI")
|
||||
req.Header.Set("X-IDE-Version", "2.63.2")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
}
|
||||
125
internal/runtime/executor/codex_continuity.go
Normal file
125
internal/runtime/executor/codex_continuity.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type codexContinuity struct {
|
||||
Key string
|
||||
Source string
|
||||
}
|
||||
|
||||
func metadataString(meta map[string]any, key string) string {
|
||||
if len(meta) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := meta[key]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func principalString(raw any) string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(v.String())
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", raw))
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCodexContinuity(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) codexContinuity {
|
||||
if promptCacheKey := strings.TrimSpace(gjson.GetBytes(req.Payload, "prompt_cache_key").String()); promptCacheKey != "" {
|
||||
return codexContinuity{Key: promptCacheKey, Source: "prompt_cache_key"}
|
||||
}
|
||||
if executionSession := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); executionSession != "" {
|
||||
return codexContinuity{Key: executionSession, Source: "execution_session"}
|
||||
}
|
||||
if ginCtx := ginContextFrom(ctx); ginCtx != nil {
|
||||
if ginCtx.Request != nil {
|
||||
if v := strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")); v != "" {
|
||||
return codexContinuity{Key: v, Source: "idempotency_key"}
|
||||
}
|
||||
}
|
||||
if v, exists := ginCtx.Get("apiKey"); exists && v != nil {
|
||||
if trimmed := principalString(v); trimmed != "" {
|
||||
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+trimmed)).String(), Source: "client_principal"}
|
||||
}
|
||||
}
|
||||
}
|
||||
if auth != nil {
|
||||
if authID := strings.TrimSpace(auth.ID); authID != "" {
|
||||
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:"+authID)).String(), Source: "auth_id"}
|
||||
}
|
||||
}
|
||||
return codexContinuity{}
|
||||
}
|
||||
|
||||
func applyCodexContinuityBody(rawJSON []byte, continuity codexContinuity) []byte {
|
||||
if continuity.Key == "" {
|
||||
return rawJSON
|
||||
}
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", continuity.Key)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
func applyCodexContinuityHeaders(headers http.Header, continuity codexContinuity) {
|
||||
if headers == nil || continuity.Key == "" {
|
||||
return
|
||||
}
|
||||
headers.Set("session_id", continuity.Key)
|
||||
}
|
||||
|
||||
func logCodexRequestDiagnostics(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, headers http.Header, body []byte, continuity codexContinuity) {
|
||||
if !log.IsLevelEnabled(log.DebugLevel) {
|
||||
return
|
||||
}
|
||||
entry := logWithRequestID(ctx)
|
||||
authID := ""
|
||||
authFile := ""
|
||||
if auth != nil {
|
||||
authID = strings.TrimSpace(auth.ID)
|
||||
authFile = strings.TrimSpace(auth.FileName)
|
||||
}
|
||||
selectedAuthID := metadataString(opts.Metadata, cliproxyexecutor.SelectedAuthMetadataKey)
|
||||
executionSessionID := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey)
|
||||
entry.Debugf(
|
||||
"codex request diagnostics auth_id=%s selected_auth_id=%s auth_file=%s exec_session=%s continuity_source=%s session_id=%s prompt_cache_key=%s prompt_cache_retention=%s store=%t has_instructions=%t reasoning_effort=%s reasoning_summary=%s chatgpt_account_id=%t originator=%s model=%s source_format=%s",
|
||||
authID,
|
||||
selectedAuthID,
|
||||
authFile,
|
||||
executionSessionID,
|
||||
continuity.Source,
|
||||
strings.TrimSpace(headers.Get("session_id")),
|
||||
gjson.GetBytes(body, "prompt_cache_key").String(),
|
||||
gjson.GetBytes(body, "prompt_cache_retention").String(),
|
||||
gjson.GetBytes(body, "store").Bool(),
|
||||
gjson.GetBytes(body, "instructions").Exists(),
|
||||
gjson.GetBytes(body, "reasoning.effort").String(),
|
||||
gjson.GetBytes(body, "reasoning.summary").String(),
|
||||
strings.TrimSpace(headers.Get("Chatgpt-Account-Id")) != "",
|
||||
strings.TrimSpace(headers.Get("Originator")),
|
||||
req.Model,
|
||||
opts.SourceFormat.String(),
|
||||
)
|
||||
}
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexOriginator = "codex_cli_rs"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
@@ -111,18 +111,19 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -183,7 +184,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -222,11 +223,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
body, _ = sjson.DeleteBytes(body, "stream")
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -273,7 +275,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -309,19 +311,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -387,7 +390,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -415,6 +418,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
@@ -432,7 +436,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
|
||||
@@ -596,8 +600,9 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
|
||||
func (e *CodexExecutor) cacheHelper(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, url string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) (*http.Request, codexContinuity, error) {
|
||||
var cache codexCache
|
||||
continuity := codexContinuity{}
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
@@ -610,30 +615,26 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
}
|
||||
setCodexCache(key, cache)
|
||||
}
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key")
|
||||
if promptCacheKey.Exists() {
|
||||
cache.ID = promptCacheKey.String()
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
|
||||
}
|
||||
} else if from == "openai" {
|
||||
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
||||
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||
}
|
||||
continuity = resolveCodexContinuity(ctx, auth, req, opts)
|
||||
cache.ID = continuity.Key
|
||||
}
|
||||
|
||||
if cache.ID != "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||
}
|
||||
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, continuity, err
|
||||
}
|
||||
if cache.ID != "" {
|
||||
httpReq.Header.Set("Conversation_id", cache.ID)
|
||||
httpReq.Header.Set("Session_id", cache.ID)
|
||||
}
|
||||
return httpReq, nil
|
||||
applyCodexContinuityHeaders(httpReq.Header, continuity)
|
||||
return httpReq, continuity, nil
|
||||
}
|
||||
|
||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
|
||||
@@ -645,8 +646,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
@@ -663,8 +666,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
isAPIKey = true
|
||||
}
|
||||
}
|
||||
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||
r.Header.Set("Originator", originator)
|
||||
} else if !isAPIKey {
|
||||
r.Header.Set("Originator", codexOriginator)
|
||||
}
|
||||
if !isAPIKey {
|
||||
r.Header.Set("Originator", "codex_cli_rs")
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||
r.Header.Set("Chatgpt-Account-Id", accountID)
|
||||
@@ -679,13 +686,39 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
errCode := statusCode
|
||||
if isCodexModelCapacityError(body) {
|
||||
errCode = http.StatusTooManyRequests
|
||||
}
|
||||
err := statusErr{code: errCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func isCodexModelCapacityError(errorBody []byte) bool {
|
||||
if len(errorBody) == 0 {
|
||||
return false
|
||||
}
|
||||
candidates := []string{
|
||||
gjson.GetBytes(errorBody, "error.message").String(),
|
||||
gjson.GetBytes(errorBody, "message").String(),
|
||||
string(errorBody),
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
lower := strings.ToLower(strings.TrimSpace(candidate))
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(lower, "selected model is at capacity") ||
|
||||
strings.Contains(lower, "model is at capacity. please try a different model") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -27,7 +28,7 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
}
|
||||
url := "https://example.com/responses"
|
||||
|
||||
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
httpReq, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
@@ -42,14 +43,14 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
if gotKey != expectedKey {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||
}
|
||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
||||
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
||||
if gotSession := httpReq.Header.Get("session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("session_id = %q, want %q", gotSession, expectedKey)
|
||||
}
|
||||
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||
if got := httpReq.Header.Get("Conversation_id"); got != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", got)
|
||||
}
|
||||
|
||||
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
httpReq2, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error (second call): %v", err)
|
||||
}
|
||||
@@ -62,3 +63,118 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
||||
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIResponses_PreservesPromptCacheRetention(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
url := "https://example.com/responses"
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.3-codex",
|
||||
Payload: []byte(`{"model":"gpt-5.3-codex","prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
|
||||
}
|
||||
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true,"prompt_cache_retention":"persistent"}`)
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai-response"), url, req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "cache-key-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := gjson.GetBytes(body, "prompt_cache_retention").String(); got != "persistent" {
|
||||
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != "cache-key-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := httpReq.Header.Get("Conversation_id"); got != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_UsesExecutionSessionForContinuity(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4"}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{Metadata: map[string]any{cliproxyexecutor.ExecutionSessionMetadataKey: "exec-session-1"}}
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai"), "https://example.com/responses", req, opts, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "exec-session-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "exec-session-1")
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != "exec-session-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "exec-session-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_FallsBackToStableAuthID(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4"}`),
|
||||
}
|
||||
auth := &cliproxyauth.Auth{ID: "codex-auth-1", Provider: "codex"}
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), auth, sdktranslator.FromString("openai"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
expected := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:codex-auth-1")).String()
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != expected {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, expected)
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != expected {
|
||||
t.Fatalf("session_id = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_ClaudePreservesCacheContinuity(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "claude-3-7-sonnet",
|
||||
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
|
||||
}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
|
||||
httpReq, continuity, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("claude"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
if continuity.Key == "" {
|
||||
t.Fatal("continuity.Key = empty, want non-empty")
|
||||
}
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != continuity.Key {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != continuity.Key {
|
||||
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
|
||||
|
||||
err := newCodexStatusErr(http.StatusBadRequest, body)
|
||||
|
||||
if got := err.StatusCode(); got != http.StatusTooManyRequests {
|
||||
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||
}
|
||||
if err.RetryAfter() != nil {
|
||||
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
@@ -178,7 +178,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||
@@ -190,7 +189,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
@@ -209,6 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -343,7 +343,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: out}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
@@ -385,7 +385,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
@@ -403,6 +403,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -592,7 +593,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
line := encodeCodexWebsocketAsSSE(payload)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m)
|
||||
for i := range chunks {
|
||||
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
|
||||
if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
|
||||
terminateReason = "context_done"
|
||||
terminateErr = ctx.Err()
|
||||
return
|
||||
@@ -761,13 +762,14 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) {
|
||||
func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) ([]byte, http.Header, codexContinuity) {
|
||||
headers := http.Header{}
|
||||
if len(rawJSON) == 0 {
|
||||
return rawJSON, headers
|
||||
return rawJSON, headers, codexContinuity{}
|
||||
}
|
||||
|
||||
var cache codexCache
|
||||
continuity := codexContinuity{}
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
@@ -781,20 +783,22 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
}
|
||||
setCodexCache(key, cache)
|
||||
}
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
|
||||
cache.ID = promptCacheKey.String()
|
||||
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
|
||||
}
|
||||
} else if from == "openai" {
|
||||
continuity = resolveCodexContinuity(ctx, auth, req, opts)
|
||||
cache.ID = continuity.Key
|
||||
}
|
||||
|
||||
if cache.ID != "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||
headers.Set("Conversation_id", cache.ID)
|
||||
headers.Set("Session_id", cache.ID)
|
||||
}
|
||||
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
|
||||
applyCodexContinuityHeaders(headers, continuity)
|
||||
|
||||
return rawJSON, headers
|
||||
return rawJSON, headers, continuity
|
||||
}
|
||||
|
||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
|
||||
@@ -814,9 +818,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "Version", "")
|
||||
|
||||
misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
|
||||
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
||||
if betaHeader == "" && ginHeaders != nil {
|
||||
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
|
||||
@@ -825,7 +830,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(headers, ginHeaders, "session_id", uuid.NewString())
|
||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
isAPIKey := false
|
||||
@@ -834,8 +839,12 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
isAPIKey = true
|
||||
}
|
||||
}
|
||||
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||
headers.Set("Originator", originator)
|
||||
} else if !isAPIKey {
|
||||
headers.Set("Originator", codexOriginator)
|
||||
}
|
||||
if !isAPIKey {
|
||||
headers.Set("Originator", "codex_cli_rs")
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -32,6 +34,49 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexPromptCacheHeaders_PreservesPromptCacheRetention(t *testing.T) {
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5-codex",
|
||||
Payload: []byte(`{"prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
|
||||
}
|
||||
body := []byte(`{"model":"gpt-5-codex","stream":true,"prompt_cache_retention":"persistent"}`)
|
||||
|
||||
updatedBody, headers, _ := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("openai-response"), req, cliproxyexecutor.Options{}, body)
|
||||
|
||||
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != "cache-key-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := gjson.GetBytes(updatedBody, "prompt_cache_retention").String(); got != "persistent" {
|
||||
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
|
||||
}
|
||||
if got := headers.Get("session_id"); got != "cache-key-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
|
||||
}
|
||||
if got := headers.Get("Conversation_id"); got != "" {
|
||||
t.Fatalf("Conversation_id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexPromptCacheHeaders_ClaudePreservesContinuity(t *testing.T) {
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "claude-3-7-sonnet",
|
||||
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
|
||||
}
|
||||
body := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
|
||||
updatedBody, headers, continuity := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("claude"), req, cliproxyexecutor.Options{}, body)
|
||||
|
||||
if continuity.Key == "" {
|
||||
t.Fatal("continuity.Key = empty, want non-empty")
|
||||
}
|
||||
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != continuity.Key {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
if got := headers.Get("session_id"); got != continuity.Key {
|
||||
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||
|
||||
@@ -41,9 +86,46 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("Version"); got != "" {
|
||||
t.Fatalf("Version = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("X-Codex-Turn-Metadata"); got != "" {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("X-Client-Request-Id"); got != "" {
|
||||
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"Originator": "Codex Desktop",
|
||||
"Version": "0.115.0-alpha.27",
|
||||
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||
})
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
|
||||
|
||||
if got := headers.Get("Originator"); got != "Codex Desktop" {
|
||||
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||
}
|
||||
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
|
||||
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||
}
|
||||
if got := headers.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
|
||||
}
|
||||
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
@@ -177,6 +259,57 @@ func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
req = req.WithContext(contextWithGinHeaders(map[string]string{
|
||||
"Originator": "Codex Desktop",
|
||||
"Version": "0.115.0-alpha.27",
|
||||
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||
}))
|
||||
|
||||
applyCodexHeaders(req, auth, "oauth-token", true, nil)
|
||||
|
||||
if got := req.Header.Get("Originator"); got != "Codex Desktop" {
|
||||
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||
}
|
||||
if got := req.Header.Get("Version"); got != "0.115.0-alpha.27" {
|
||||
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||
}
|
||||
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
|
||||
}
|
||||
if got := req.Header.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersDoesNotInjectClientOnlyHeadersByDefault(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
|
||||
applyCodexHeaders(req, nil, "oauth-token", true, nil)
|
||||
|
||||
if got := req.Header.Get("Version"); got != "" {
|
||||
t.Fatalf("Version = %q, want empty", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != "" {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Client-Request-Id"); got != "" {
|
||||
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func contextWithGinHeaders(headers map[string]string) context.Context {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
1719
internal/runtime/executor/cursor_executor.go
Normal file
1719
internal/runtime/executor/cursor_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -224,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -401,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -430,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
@@ -544,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
@@ -811,18 +811,18 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
||||
|
||||
if !hasInlineData {
|
||||
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
||||
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
||||
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := `[]`
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
||||
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
|
||||
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := []byte(`[]`)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
|
||||
|
||||
parts := contentArray[0].Get("parts").Array()
|
||||
for j := 0; j < len(parts); j++ {
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson)
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -321,12 +321,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -415,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
@@ -527,18 +527,18 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
||||
|
||||
if !hasInlineData {
|
||||
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
||||
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
||||
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := `[]`
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
||||
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
|
||||
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := []byte(`[]`)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
|
||||
|
||||
parts := contentArray[0].Get("parts").Array()
|
||||
for j := 0; j < len(parts); j++ {
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", newPartsJson)
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -524,7 +524,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -636,12 +636,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -760,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -941,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
@@ -221,13 +221,13 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
var param any
|
||||
converted := ""
|
||||
var converted []byte
|
||||
if useResponses && from.String() == "claude" {
|
||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||
} else {
|
||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: converted}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -374,14 +374,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
var chunks [][]byte
|
||||
if useResponses && from.String() == "claude" {
|
||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||
} else {
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
}
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,9 +577,33 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model
|
||||
return true
|
||||
}
|
||||
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(baseModel, githubCopilotAuthType); info != nil {
|
||||
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
|
||||
}
|
||||
if info := lookupGitHubCopilotStaticModelInfo(baseModel); info != nil {
|
||||
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
|
||||
}
|
||||
return strings.Contains(baseModel, "codex")
|
||||
}
|
||||
|
||||
func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
|
||||
for _, info := range registry.GetStaticModelDefinitionsByChannel(githubCopilotAuthType) {
|
||||
if info != nil && strings.EqualFold(info.ID, model) {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsEndpoint(endpoints []string, endpoint string) bool {
|
||||
for _, item := range endpoints {
|
||||
if item == endpoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
// to a joined string. GitHub Copilot requires assistant content as a string;
|
||||
// sending it as an array causes Claude models to re-answer all previous prompts.
|
||||
@@ -977,7 +1001,7 @@ type githubCopilotResponsesStreamState struct {
|
||||
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) []byte {
|
||||
root := gjson.ParseBytes(data)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
@@ -1067,10 +1091,10 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
}
|
||||
return out
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &githubCopilotResponsesStreamState{
|
||||
TextBlockIndex: -1,
|
||||
@@ -1092,7 +1116,10 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
|
||||
event := gjson.GetBytes(payload, "type").String()
|
||||
results := make([]string, 0, 4)
|
||||
results := make([][]byte, 0, 4)
|
||||
appendResult := func(chunk string) {
|
||||
results = append(results, []byte(chunk))
|
||||
}
|
||||
ensureMessageStart := func() {
|
||||
if state.MessageStarted {
|
||||
return
|
||||
@@ -1100,7 +1127,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
||||
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
||||
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
||||
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
|
||||
appendResult("event: message_start\ndata: " + messageStart + "\n\n")
|
||||
state.MessageStarted = true
|
||||
}
|
||||
startTextBlockIfNeeded := func() {
|
||||
@@ -1113,7 +1140,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
|
||||
state.TextBlockStarted = true
|
||||
}
|
||||
stopTextBlockIfNeeded := func() {
|
||||
@@ -1122,7 +1149,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
|
||||
state.TextBlockStarted = false
|
||||
state.TextBlockIndex = -1
|
||||
}
|
||||
@@ -1152,7 +1179,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
||||
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + contentDelta + "\n\n")
|
||||
}
|
||||
case "response.reasoning_summary_part.added":
|
||||
ensureMessageStart()
|
||||
@@ -1161,7 +1188,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
state.NextContentIndex++
|
||||
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
|
||||
appendResult("event: content_block_start\ndata: " + thinkingStart + "\n\n")
|
||||
case "response.reasoning_summary_text.delta":
|
||||
if state.ReasoningActive {
|
||||
delta := gjson.GetBytes(payload, "delta").String()
|
||||
@@ -1169,14 +1196,14 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
|
||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + thinkingDelta + "\n\n")
|
||||
}
|
||||
}
|
||||
case "response.reasoning_summary_part.done":
|
||||
if state.ReasoningActive {
|
||||
thinkingStop := `{"type":"content_block_stop","index":0}`
|
||||
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
|
||||
appendResult("event: content_block_stop\ndata: " + thinkingStop + "\n\n")
|
||||
state.ReasoningActive = false
|
||||
}
|
||||
case "response.output_item.added":
|
||||
@@ -1204,7 +1231,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
|
||||
case "response.output_item.delta":
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if item.Get("type").String() != "function_call" {
|
||||
@@ -1224,7 +1251,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
|
||||
case "response.function_call_arguments.delta":
|
||||
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
|
||||
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
|
||||
@@ -1241,7 +1268,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
|
||||
case "response.output_item.done":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
@@ -1252,7 +1279,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
|
||||
case "response.completed":
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
@@ -1276,8 +1303,8 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
if cachedTokens > 0 {
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
|
||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
appendResult("event: message_delta\ndata: " + messageDelta + "\n\n")
|
||||
appendResult("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
state.MessageStopSent = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -70,6 +71,29 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "github-copilot-test-client"
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
|
||||
ID: "gpt-5.4",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
}})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
|
||||
@@ -30,12 +30,20 @@ const (
|
||||
gitLabChatEndpoint = "/api/v4/chat/completions"
|
||||
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
|
||||
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
|
||||
gitLabContext1MBeta = "context-1m-2025-08-07"
|
||||
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
|
||||
)
|
||||
|
||||
type GitLabExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type gitLabCatalogModel struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Provider string
|
||||
}
|
||||
|
||||
type gitLabPrompt struct {
|
||||
Instruction string
|
||||
FileName string
|
||||
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
|
||||
Finished bool
|
||||
}
|
||||
|
||||
var gitLabAgenticCatalog = []gitLabCatalogModel{
|
||||
{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
|
||||
{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
|
||||
{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
|
||||
{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
|
||||
{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
|
||||
}
|
||||
|
||||
var gitLabModelAliases = map[string]string{
|
||||
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
|
||||
}
|
||||
|
||||
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
||||
return &GitLabExecutor{cfg: cfg}
|
||||
}
|
||||
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
|
||||
nativeReq := req
|
||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||
}
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
|
||||
nativeReq := req
|
||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
|
||||
}
|
||||
|
||||
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
|
||||
return NewClaudeExecutor(e.cfg), nativeAuth
|
||||
}
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
|
||||
return NewCodexExecutor(e.cfg), nativeAuth
|
||||
}
|
||||
return nil, nil
|
||||
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
if auth != nil {
|
||||
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
|
||||
}
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
for key, value := range gitLabGatewayHeaders(auth, "") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
}
|
||||
}
|
||||
|
||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
|
||||
out := make(map[string]string)
|
||||
switch typed := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key != "" && value != "" {
|
||||
out[key] = value
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||
if ok {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]string:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key != "" && value != "" {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||
if strValue != "" {
|
||||
out[key] = strValue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
for key, value := range typed {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||
if strValue != "" {
|
||||
out[key] = strValue
|
||||
}
|
||||
}
|
||||
if _, ok := out["User-Agent"]; !ok {
|
||||
out["User-Agent"] = gitLabNativeUserAgent
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
|
||||
if _, ok := out["anthropic-beta"]; !ok {
|
||||
out["anthropic-beta"] = gitLabContext1MBeta
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
|
||||
return promptTokens, int64(completionCount)
|
||||
}
|
||||
|
||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesAnthropicGateway(auth) {
|
||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesAnthropicGateway(auth, requestedModel) {
|
||||
return nil, false
|
||||
}
|
||||
baseURL := gitLabAnthropicGatewayBaseURL(auth)
|
||||
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
||||
}
|
||||
nativeAuth.Attributes["api_key"] = token
|
||||
nativeAuth.Attributes["base_url"] = baseURL
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
|
||||
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
||||
return nativeAuth, true
|
||||
}
|
||||
|
||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesOpenAIGateway(auth) {
|
||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||
if !gitLabUsesOpenAIGateway(auth, requestedModel) {
|
||||
return nil, false
|
||||
}
|
||||
baseURL := gitLabOpenAIGatewayBaseURL(auth)
|
||||
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
||||
}
|
||||
nativeAuth.Attributes["api_key"] = token
|
||||
nativeAuth.Attributes["base_url"] = baseURL
|
||||
for key, value := range gitLabGatewayHeaders(auth) {
|
||||
for key, value := range gitLabGatewayHeaders(auth, "openai") {
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
||||
return nativeAuth, true
|
||||
}
|
||||
|
||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool {
|
||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
provider = inferGitLabProviderFromModel(modelName)
|
||||
}
|
||||
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||
return provider == "anthropic" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||
}
|
||||
|
||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool {
|
||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
provider = inferGitLabProviderFromModel(modelName)
|
||||
}
|
||||
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||
return provider == "openai" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||
}
|
||||
|
||||
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
|
||||
modelName := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
|
||||
if provider := inferGitLabProviderFromModel(modelName); provider != "" {
|
||||
return provider
|
||||
}
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||
if provider == "" {
|
||||
provider = inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func inferGitLabProviderFromModel(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
|
||||
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
|
||||
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
|
||||
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
|
||||
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
|
||||
return mapped
|
||||
}
|
||||
return requested
|
||||
}
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
|
||||
}
|
||||
|
||||
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
models := make([]*registry.ModelInfo, 0, 4)
|
||||
seen := make(map[string]struct{}, 4)
|
||||
models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
|
||||
seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
|
||||
addModel := func(id, displayName, provider string) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||
}
|
||||
|
||||
addModel("gitlab-duo", "GitLab Duo", "gitlab")
|
||||
for _, model := range gitLabAgenticCatalog {
|
||||
addModel(model.ID, model.DisplayName, model.Provider)
|
||||
}
|
||||
for alias, upstream := range gitLabModelAliases {
|
||||
target := strings.TrimSpace(upstream)
|
||||
displayName := "GitLab Duo Alias"
|
||||
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
|
||||
if provider != "" {
|
||||
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
|
||||
}
|
||||
addModel(alias, displayName, provider)
|
||||
}
|
||||
if auth == nil {
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotBetaHeader = r.Header.Get("anthropic-beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "duo-chat-gpt-5-codex",
|
||||
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotBetaHeader != gitLabContext1MBeta {
|
||||
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if gotModel != "duo-chat-gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
|
||||
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"oauth_client_secret": "client-secret",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotPath, gotBetaHeader, gotUserAgent string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotBetaHeader = r.Header.Get("Anthropic-Beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
|
||||
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
@@ -169,7 +169,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -281,7 +281,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -315,7 +315,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.
|
||||
|
||||
@@ -161,7 +161,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -271,12 +271,12 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
|
||||
@@ -89,6 +89,13 @@ var endpointAliases = map[string]string{
|
||||
"cli": "amazonq",
|
||||
}
|
||||
|
||||
func enqueueTranslatedSSE(out chan<- cliproxyexecutor.StreamChunk, chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: append(bytes.Clone(chunk), '\n', '\n')}
|
||||
}
|
||||
|
||||
// retryConfig holds configuration for socket retry logic.
|
||||
// Based on kiro2Api Python implementation patterns.
|
||||
type retryConfig struct {
|
||||
@@ -2573,9 +2580,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Send tool input as delta
|
||||
@@ -2583,18 +2588,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Close block
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
hasToolUses = true
|
||||
@@ -2664,9 +2665,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
messageStartSent = true
|
||||
}
|
||||
@@ -2916,9 +2915,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
lastReportedOutputTokens = currentOutputTokens
|
||||
@@ -2939,17 +2936,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
continue
|
||||
@@ -2978,18 +2971,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
// Send thinking delta
|
||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
accumulatedThinkingContent.WriteString(thinkingText)
|
||||
}
|
||||
@@ -2998,9 +2987,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isThinkingBlockOpen = false
|
||||
}
|
||||
@@ -3029,17 +3016,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
accumulatedThinkingContent.WriteString(processContent)
|
||||
}
|
||||
@@ -3058,9 +3041,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isThinkingBlockOpen = false
|
||||
}
|
||||
@@ -3071,18 +3052,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
// Send text delta
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
// Close text block before entering thinking
|
||||
@@ -3090,9 +3067,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3120,17 +3095,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3158,9 +3129,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3171,9 +3140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Send input_json_delta with the tool input
|
||||
@@ -3186,9 +3153,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3197,9 +3162,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3239,9 +3202,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3254,9 +3215,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3264,9 +3223,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Accumulate for token counting
|
||||
@@ -3298,9 +3255,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3310,9 +3265,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
if tu.Input != nil {
|
||||
@@ -3323,9 +3276,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3333,9 +3284,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3522,9 +3471,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3609,18 +3556,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Send message_stop event separately
|
||||
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
// reporter.publish is called via defer
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -290,7 +290,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -330,7 +330,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||
}
|
||||
|
||||
// Refresh is a no-op for API-key based compatibility providers.
|
||||
|
||||
@@ -305,7 +305,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -421,12 +421,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -461,7 +461,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -73,17 +73,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
usage.PublishRecord(ctx, usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Failed: failed,
|
||||
Detail: detail,
|
||||
})
|
||||
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -96,20 +86,39 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
usage.PublishRecord(ctx, usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Failed: false,
|
||||
Detail: usage.Detail{},
|
||||
})
|
||||
usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||
if r == nil {
|
||||
return usage.Record{Detail: detail, Failed: failed}
|
||||
}
|
||||
return usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Latency: r.latency(),
|
||||
Failed: failed,
|
||||
Detail: detail,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageReporter) latency() time.Duration {
|
||||
if r == nil || r.requestedAt.IsZero() {
|
||||
return 0
|
||||
}
|
||||
latency := time.Since(r.requestedAt)
|
||||
if latency < 0 {
|
||||
return 0
|
||||
}
|
||||
return latency
|
||||
}
|
||||
|
||||
func apiKeyFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package executor
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||
@@ -41,3 +46,19 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||
reporter := &usageReporter{
|
||||
provider: "openai",
|
||||
model: "gpt-5.4",
|
||||
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
||||
}
|
||||
|
||||
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
|
||||
if record.Latency < time.Second {
|
||||
t.Fatalf("latency = %v, want >= 1s", record.Latency)
|
||||
}
|
||||
if record.Latency > 3*time.Second {
|
||||
t.Fatalf("latency = %v, want <= 3s", record.Latency)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -39,35 +40,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
var systemInstructionJSON []byte
|
||||
hasSystemInstruction := false
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
partJSON := `{}`
|
||||
partJSON := []byte(`{}`)
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
|
||||
}
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
||||
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", partJSON)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
|
||||
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
|
||||
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
|
||||
// contents
|
||||
contentsJSON := "[]"
|
||||
contentsJSON := []byte(`[]`)
|
||||
hasContents := false
|
||||
|
||||
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
|
||||
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
|
||||
toolNameByID := make(map[string]string)
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -83,8 +88,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContentJSON := `{"role":"","parts":[]}`
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
||||
clientContentJSON := []byte(`{"role":"","parts":[]}`)
|
||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "role", role)
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
@@ -143,15 +148,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
if thinkingText != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
|
||||
}
|
||||
// Always include "text" field — Google Antigravity API requires it
|
||||
// even for redacted thinking where the text is empty.
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
// Skip empty text parts to avoid Gemini API error:
|
||||
@@ -159,17 +164,21 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if prompt == "" {
|
||||
continue
|
||||
}
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
|
||||
if functionID != "" && functionName != "" {
|
||||
toolNameByID[functionID] = functionName
|
||||
}
|
||||
|
||||
// Handle both object and string input formats
|
||||
var argsRaw string
|
||||
if argsResult.IsObject() {
|
||||
@@ -183,138 +192,147 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
if argsRaw != "" {
|
||||
partJSON := `{}`
|
||||
partJSON := []byte(`{}`)
|
||||
|
||||
// Use skip_thought_signature_validator for tool calls without valid thinking signature
|
||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||
// and also works for Claude through Antigravity API
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
} else {
|
||||
// No valid signature - use skip sentinel to bypass validation
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", skipSentinel)
|
||||
}
|
||||
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRawBytes(partJSON, "functionCall.args", []byte(argsRaw))
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
funcName, ok := toolNameByID[toolCallID]
|
||||
if !ok {
|
||||
// Fallback: derive a semantic name from the ID by stripping
|
||||
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
|
||||
// Only use the raw ID as a last resort when the heuristic produces an empty string.
|
||||
parts := strings.Split(toolCallID, "-")
|
||||
if len(parts) > 2 {
|
||||
funcName = strings.Join(parts[:len(parts)-2], "-")
|
||||
}
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
|
||||
}
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
functionResponseJSON := `{}`
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
||||
functionResponseJSON := []byte(`{}`)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
nonImageCount := 0
|
||||
lastNonImageRaw := ""
|
||||
filteredJSON := "[]"
|
||||
imagePartsJSON := "[]"
|
||||
filteredJSON := []byte(`[]`)
|
||||
imagePartsJSON := []byte(`[]`)
|
||||
for _, fr := range frResults {
|
||||
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
inlineDataJSON := []byte(`{}`)
|
||||
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := fr.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
imagePartJSON := []byte(`{}`)
|
||||
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
|
||||
continue
|
||||
}
|
||||
|
||||
nonImageCount++
|
||||
lastNonImageRaw = fr.Raw
|
||||
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||
filteredJSON, _ = sjson.SetRawBytes(filteredJSON, "-1", []byte(fr.Raw))
|
||||
}
|
||||
|
||||
if nonImageCount == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(lastNonImageRaw))
|
||||
} else if nonImageCount > 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", filteredJSON)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
// Place image data inside functionResponse.parts as inlineData
|
||||
// instead of as sibling parts in the outer content, to avoid
|
||||
// base64 data bloating the text context.
|
||||
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
if gjson.GetBytes(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
inlineDataJSON := []byte(`{}`)
|
||||
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := "[]"
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
imagePartJSON := []byte(`{}`)
|
||||
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := []byte(`[]`)
|
||||
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
|
||||
}
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
|
||||
} else {
|
||||
// Content field is missing entirely — .Raw is empty which
|
||||
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetRawBytes(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||
sourceResult := contentResult.Get("source")
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
inlineDataJSON := []byte(`{}`)
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetRawBytes(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder parts for 'model' role to ensure thinking block is first
|
||||
if role == "model" {
|
||||
partsResult := gjson.Get(clientContentJSON, "parts")
|
||||
partsResult := gjson.GetBytes(clientContentJSON, "parts")
|
||||
if partsResult.IsArray() {
|
||||
parts := partsResult.Array()
|
||||
var thinkingParts []gjson.Result
|
||||
@@ -336,7 +354,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
for _, p := range otherParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
|
||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -344,33 +362,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Skip messages with empty parts array to avoid Gemini API error:
|
||||
// "required oneof field 'data' must have one initialized field"
|
||||
partsCheck := gjson.Get(clientContentJSON, "parts")
|
||||
partsCheck := gjson.GetBytes(clientContentJSON, "parts")
|
||||
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
partJSON := `{}`
|
||||
partJSON := []byte(`{}`)
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
toolsJSON := ""
|
||||
var toolsJSON []byte
|
||||
toolDeclCount := 0
|
||||
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
||||
toolsJSON = []byte(`[{"functionDeclarations":[]}]`)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
@@ -378,23 +396,24 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
// Sanitize the input schema for Antigravity API compatibility
|
||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
for toolKey := range gjson.Parse(tool).Map() {
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
for toolKey := range gjson.ParseBytes(tool).Map() {
|
||||
if util.InArray(allowedToolKeys, toolKey) {
|
||||
continue
|
||||
}
|
||||
tool, _ = sjson.Delete(tool, toolKey)
|
||||
tool, _ = sjson.DeleteBytes(tool, toolKey)
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out := []byte(`{"model":"","request":{"contents":[]}}`)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
@@ -408,27 +427,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
if hasSystemInstruction {
|
||||
// Append hint as a new part to existing system instruction
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
hintPart := []byte(`{"text":""}`)
|
||||
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
|
||||
} else {
|
||||
// Create new system instruction with hint
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
|
||||
hintPart := []byte(`{"text":""}`)
|
||||
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
if hasContents {
|
||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents", contentsJSON)
|
||||
}
|
||||
if toolDeclCount > 0 {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// tool_choice
|
||||
@@ -445,15 +464,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
switch toolChoiceType {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
case "tool":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -464,8 +483,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive", "auto":
|
||||
// For adaptive thinking:
|
||||
@@ -477,28 +496,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", v.Num)
|
||||
}
|
||||
|
||||
outBytes := []byte(out)
|
||||
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
|
||||
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
|
||||
|
||||
return outBytes
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "get_weather-call-123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function response conversion
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Error("functionResponse should exist")
|
||||
}
|
||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||
}
|
||||
if funcResp.Get("name").String() != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "**/*.py"}
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "file1.py\nfile2.py"
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"content": "total 10"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp0.Exists() {
|
||||
t.Fatal("first functionResponse should exist")
|
||||
}
|
||||
if got := funcResp0.Get("name").String(); got != "Glob" {
|
||||
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
|
||||
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
|
||||
if !funcResp1.Exists() {
|
||||
t.Fatal("second functionResponse should exist")
|
||||
}
|
||||
if got := funcResp1.Get("name").String(); got != "Bash" {
|
||||
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "Read-1773420180464065165-1327",
|
||||
"name": "Read",
|
||||
"input": {"file_path": "/tmp/test.py"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-1773420180464065165-1327",
|
||||
"content": "file content here"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "Read" {
|
||||
t.Errorf("Expected name 'Read', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "get_weather-call-123",
|
||||
"content": "22C sunny"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "get_weather" {
|
||||
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "result data"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
got := funcResp.Get("name").String()
|
||||
if got == "" {
|
||||
t.Error("functionResponse.name must not be empty")
|
||||
}
|
||||
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
|
||||
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +44,10 @@ type Params struct {
|
||||
|
||||
// Signature caching support
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
|
||||
// Reverse map: sanitized Gemini function name → original Claude tool name.
|
||||
// Populated lazily on the first response chunk from the original request JSON.
|
||||
ToolNameMap map[string]string
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -63,13 +68,14 @@ var toolUseIDCounter uint64
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
||||
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
|
||||
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &Params{
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
@@ -77,44 +83,44 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
params := (*param).(*Params)
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
output := ""
|
||||
output := make([]byte, 0, 256)
|
||||
// Only send final events if we have actually output content
|
||||
if params.HasContent {
|
||||
appendFinalEvents(params, &output, true)
|
||||
return []string{
|
||||
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
|
||||
}
|
||||
output = translatorcommon.AppendSSEEventString(output, "message_stop", `{"type":"message_stop"}`, 3)
|
||||
return [][]byte{output}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
output := ""
|
||||
output := make([]byte, 0, 1024)
|
||||
appendEvent := func(event, payload string) {
|
||||
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
|
||||
}
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
// This is only sent for the very first response chunk to establish the streaming session
|
||||
if !params.HasFirstResponse {
|
||||
output = "event: message_start\n"
|
||||
|
||||
// Create the initial message structure with default values according to Claude Code API specification
|
||||
// This follows the Claude Code API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
messageStartTemplate := []byte(`{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`)
|
||||
|
||||
// Use cpaUsageMetadata within the message_start event for Claude.
|
||||
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||
}
|
||||
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||
}
|
||||
|
||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
}
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
}
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
||||
appendEvent("message_start", string(messageStartTemplate))
|
||||
|
||||
params.HasFirstResponse = true
|
||||
}
|
||||
@@ -144,15 +150,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex)), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.HasContent = true
|
||||
} else {
|
||||
// Transition from another state to thinking
|
||||
@@ -163,19 +167,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new thinking content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex))
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.ResponseType = 2 // Set state to thinking
|
||||
params.HasContent = true
|
||||
// Start accumulating thinking text for signature caching
|
||||
@@ -188,9 +187,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block if already in content state
|
||||
if params.ResponseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.HasContent = true
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
@@ -201,19 +199,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
}
|
||||
if partTextResult.String() != "" {
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex))
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.ResponseType = 1 // Set state to content
|
||||
params.HasContent = true
|
||||
}
|
||||
@@ -224,14 +217,12 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude Code API compatibility
|
||||
params.HasToolUse = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if params.ResponseType == 3 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
params.ResponseType = 0
|
||||
}
|
||||
@@ -245,26 +236,21 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
// Close any other existing content block
|
||||
if params.ResponseType != 0 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new tool use content block
|
||||
// This creates the structure for a function call in Claude Code format
|
||||
output = output + "event: content_block_start\n"
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex))
|
||||
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
|
||||
appendEvent("content_block_start", string(data))
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
|
||||
appendEvent("content_block_delta", string(data))
|
||||
}
|
||||
params.ResponseType = 3
|
||||
params.HasContent = true
|
||||
@@ -296,10 +282,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
appendFinalEvents(params, &output, false)
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
return [][]byte{output}
|
||||
}
|
||||
|
||||
func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
func appendFinalEvents(params *Params, output *[]byte, force bool) {
|
||||
if params.HasSentFinalEvents {
|
||||
return
|
||||
}
|
||||
@@ -314,9 +300,7 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
}
|
||||
|
||||
if params.ResponseType != 0 {
|
||||
*output = *output + "event: content_block_stop\n"
|
||||
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
*output = *output + "\n\n\n"
|
||||
*output = translatorcommon.AppendSSEEventString(*output, "content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex), 3)
|
||||
params.ResponseType = 0
|
||||
}
|
||||
|
||||
@@ -329,18 +313,16 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
}
|
||||
}
|
||||
|
||||
*output = *output + "event: message_delta\n"
|
||||
*output = *output + "data: "
|
||||
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
||||
delta := []byte(fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens))
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if params.CachedTokenCount > 0 {
|
||||
var err error
|
||||
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||
delta, err = sjson.SetBytes(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
*output = *output + delta + "\n\n\n"
|
||||
*output = translatorcommon.AppendSSEEventString(*output, "message_delta", string(delta), 3)
|
||||
|
||||
params.HasSentFinalEvents = true
|
||||
}
|
||||
@@ -369,9 +351,9 @@ func resolveStopReason(params *Params) string {
|
||||
// - param: A pointer to a parameter object for the conversion.
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
_ = originalRequestRawJSON
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -388,15 +370,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
|
||||
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||
responseJSON := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.output_tokens", outputTokens)
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if cachedTokens > 0 {
|
||||
var err error
|
||||
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
responseJSON, err = sjson.SetBytes(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
@@ -407,7 +389,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
if contentArrayInitialized {
|
||||
return
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content", []byte("[]"))
|
||||
contentArrayInitialized = true
|
||||
}
|
||||
|
||||
@@ -423,9 +405,9 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
return
|
||||
}
|
||||
ensureContentArray()
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
@@ -434,12 +416,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
return
|
||||
}
|
||||
ensureContentArray()
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||
block, _ = sjson.SetBytes(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
thinkingSignature = ""
|
||||
}
|
||||
@@ -473,18 +455,18 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
|
||||
toolIDCounter++
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
||||
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(args.Raw))
|
||||
}
|
||||
|
||||
ensureContentArray()
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -508,17 +490,17 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
}
|
||||
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "stop_reason", stopReason)
|
||||
|
||||
if promptTokens == 0 && outputTokens == 0 {
|
||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
||||
responseJSON, _ = sjson.DeleteBytes(responseJSON, "usage")
|
||||
}
|
||||
}
|
||||
|
||||
return responseJSON
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||
}
|
||||
|
||||
@@ -34,10 +34,10 @@ import (
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template := `{"project":"","request":{},"model":""}`
|
||||
templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
|
||||
templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
|
||||
template = string(templateBytes)
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||
@@ -47,7 +47,8 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
|
||||
template = string(templateBytes)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
}
|
||||
rawJSON = []byte(template)
|
||||
@@ -149,7 +150,8 @@ func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string
|
||||
raw := response.Raw
|
||||
name := response.Get("functionResponse.name").String()
|
||||
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
||||
raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName)
|
||||
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
|
||||
raw = string(updated)
|
||||
}
|
||||
return raw
|
||||
}
|
||||
@@ -157,27 +159,27 @@ func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string
|
||||
log.Debugf("parse function response failed, using fallback")
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
name := funcResp.Get("name").String()
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = fallbackName
|
||||
}
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", name)
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr, _ = sjson.Set(fr, "functionResponse.id", id)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
|
||||
}
|
||||
return fr
|
||||
return string(fr)
|
||||
}
|
||||
|
||||
useName := fallbackName
|
||||
if useName == "" {
|
||||
useName = "unknown"
|
||||
}
|
||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", useName)
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
|
||||
return fr
|
||||
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
|
||||
return string(fr)
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
@@ -204,7 +206,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
contentsWrapper := `{"contents":[]}`
|
||||
contentsWrapper := []byte(`{"contents":[]}`)
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
@@ -237,16 +239,16 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||
for ri, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,7 +271,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse model content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
@@ -283,7 +285,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
@@ -291,7 +293,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -303,23 +305,22 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||
for ri, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
|
||||
result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
|
||||
|
||||
return result, nil
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -29,8 +29,8 @@ import (
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - []string: The transformed request data in Gemini API format
|
||||
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
|
||||
// - [][]byte: The transformed response data in Gemini API format.
|
||||
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
}
|
||||
@@ -44,22 +44,22 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
||||
chunk = restoreUsageMetadata(chunk)
|
||||
}
|
||||
} else {
|
||||
chunkTemplate := "[]"
|
||||
chunkTemplate := []byte("[]")
|
||||
responseResult := gjson.ParseBytes(chunk)
|
||||
if responseResult.IsArray() {
|
||||
responseResultItems := responseResult.Array()
|
||||
for i := 0; i < len(responseResultItems); i++ {
|
||||
responseResultItem := responseResultItems[i]
|
||||
if responseResultItem.Get("response").Exists() {
|
||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
||||
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
chunk = []byte(chunkTemplate)
|
||||
chunk = chunkTemplate
|
||||
}
|
||||
return []string{string(chunk)}
|
||||
return [][]byte{chunk}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
|
||||
@@ -73,18 +73,18 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing the response data
|
||||
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing the response data.
|
||||
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
|
||||
return string(chunk)
|
||||
return chunk
|
||||
}
|
||||
return string(rawJSON)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.
|
||||
|
||||
@@ -59,8 +59,8 @@ func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected)
|
||||
if string(result) != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", string(result), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -87,8 +87,8 @@ func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
if results[0] != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected)
|
||||
if string(results[0]) != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", string(results[0]), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
@@ -309,7 +309,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -354,33 +354,39 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if errRename != nil {
|
||||
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
fnRaw = string(fnRawBytes)
|
||||
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
} else {
|
||||
fnRaw = renamed
|
||||
}
|
||||
} else {
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
fnRaw = string(fnRawBytes)
|
||||
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
fnRawBytes := []byte(fnRaw)
|
||||
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw, _ = sjson.Delete(string(fnRawBytes), "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
@@ -26,6 +27,7 @@ type convertCliResponseToOpenAIChatParams struct {
|
||||
FunctionIndex int
|
||||
SawToolCall bool // Tracks if any tool call was seen in the entire stream
|
||||
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -44,25 +46,29 @@ var functionCallIDCounter uint64
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
@@ -71,14 +77,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if err == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
||||
@@ -90,21 +96,21 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
@@ -141,33 +147,33 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
|
||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||
functionCallIndex = len(toolCallsResult.Array())
|
||||
} else {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data == "" {
|
||||
@@ -181,16 +187,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
|
||||
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
|
||||
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -212,11 +218,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
} else {
|
||||
finishReason = "stop"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
|
||||
@@ -231,11 +237,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
|
||||
}
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||
if len(result1) != 1 {
|
||||
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
||||
}
|
||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
||||
}
|
||||
@@ -33,13 +33,13 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||
if len(result2) != 1 {
|
||||
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
||||
}
|
||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr2 != "tool_calls" {
|
||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
||||
}
|
||||
|
||||
// Verify native_finish_reason is lowercase upstream value
|
||||
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
|
||||
nfr2 := gjson.GetBytes(result2[0], "choices.0.native_finish_reason").String()
|
||||
if nfr2 != "stop" {
|
||||
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func TestFinishReasonStopForNormalText(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "stop" (no tool calls were made)
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "stop" {
|
||||
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func TestFinishReasonMaxTokens(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "max_tokens"
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "max_tokens" {
|
||||
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "tool_calls" {
|
||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Verify no finish_reason on intermediate chunk
|
||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify no finish_reason on intermediate chunk
|
||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
|
||||
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason")
|
||||
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
rawJSON = []byte(responseResult.Raw)
|
||||
@@ -15,7 +15,7 @@ func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName
|
||||
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
|
||||
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
rawJSON = []byte(responseResult.Raw)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"context"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||
"github.com/tidwall/sjson"
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
)
|
||||
|
||||
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
|
||||
@@ -23,15 +23,13 @@ import (
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
// Wrap each converted response in a "response" object to match Gemini CLI API structure
|
||||
newOutputs := make([]string, 0)
|
||||
newOutputs := make([][]byte, 0, len(outputs))
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
json := `{"response": {}}`
|
||||
output, _ := sjson.SetRaw(json, "response", outputs[i])
|
||||
newOutputs = append(newOutputs, output)
|
||||
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
|
||||
}
|
||||
return newOutputs
|
||||
}
|
||||
@@ -47,15 +45,13 @@ func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, ori
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
strJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
// - []byte: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
// Wrap the converted response in a "response" object to match Gemini CLI API structure
|
||||
json := `{"response": {}}`
|
||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
||||
return strJSON
|
||||
return translatorcommon.WrapGeminiCLIResponse(out)
|
||||
}
|
||||
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
|
||||
return GeminiTokenCount(ctx, count)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
|
||||
|
||||
// Base Claude message payload
|
||||
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
|
||||
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -87,20 +87,20 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
var pendingToolIDs []string
|
||||
|
||||
// Model mapping to specify which Claude Code model to use
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Generation config extraction from Gemini format
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
// Max output tokens configuration
|
||||
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := genConfig.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
|
||||
} else if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
|
||||
}
|
||||
// Stop sequences configuration for custom termination conditions
|
||||
if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() {
|
||||
@@ -110,7 +110,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
return true
|
||||
})
|
||||
if len(stopSequences) > 0 {
|
||||
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
|
||||
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
|
||||
}
|
||||
}
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
@@ -132,30 +132,30 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
|
||||
level = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", level)
|
||||
}
|
||||
} else {
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
default:
|
||||
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,37 +169,37 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if supportsAdaptive {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
level, ok := thinking.ConvertBudgetToLevel(budget)
|
||||
if ok {
|
||||
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
|
||||
level = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", level)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,9 +220,9 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
})
|
||||
if systemText.Len() > 0 {
|
||||
// Create system message in Claude Code format
|
||||
systemMessage := `{"role":"user","content":[{"type":"text","text":""}]}`
|
||||
systemMessage, _ = sjson.Set(systemMessage, "content.0.text", systemText.String())
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", systemMessage)
|
||||
systemMessage := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
|
||||
systemMessage, _ = sjson.SetBytes(systemMessage, "content.0.text", systemText.String())
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,42 +245,42 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Create message structure in Claude Code format
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
|
||||
if parts := content.Get("parts"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// Text content conversion
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", text.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
|
||||
textContent := []byte(`{"type":"text","text":""}`)
|
||||
textContent, _ = sjson.SetBytes(textContent, "text", text.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
|
||||
return true
|
||||
}
|
||||
|
||||
// Function call (from model/assistant) conversion to tool use
|
||||
if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" {
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
|
||||
// Generate a unique tool ID and enqueue it for later matching
|
||||
// with the corresponding functionResponse
|
||||
toolID := genToolCallID()
|
||||
pendingToolIDs = append(pendingToolIDs, toolID)
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolID)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "id", toolID)
|
||||
|
||||
if name := fc.Get("name"); name.Exists() {
|
||||
toolUse, _ = sjson.Set(toolUse, "name", name.String())
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "name", name.String())
|
||||
}
|
||||
if args := fc.Get("args"); args.Exists() && args.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(args.Raw))
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
|
||||
return true
|
||||
}
|
||||
|
||||
// Function response (from user) conversion to tool result
|
||||
if fr := part.Get("functionResponse"); fr.Exists() {
|
||||
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
|
||||
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
|
||||
|
||||
// Attach the oldest queued tool_id to pair the response
|
||||
// with its call. If the queue is empty, generate a new id.
|
||||
@@ -293,41 +293,41 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Fallback: generate new ID if no pending tool_use found
|
||||
toolID = genToolCallID()
|
||||
}
|
||||
toolResult, _ = sjson.Set(toolResult, "tool_use_id", toolID)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", toolID)
|
||||
|
||||
// Extract result content from the function response
|
||||
if result := fr.Get("response.result"); result.Exists() {
|
||||
toolResult, _ = sjson.Set(toolResult, "content", result.String())
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "content", result.String())
|
||||
} else if response := fr.Get("response"); response.Exists() {
|
||||
toolResult, _ = sjson.Set(toolResult, "content", response.Raw)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "content", response.Raw)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolResult)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolResult)
|
||||
return true
|
||||
}
|
||||
|
||||
// Image content (inline_data) conversion to Claude Code format
|
||||
if inlineData := part.Get("inline_data"); inlineData.Exists() {
|
||||
imageContent := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imageContent := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
if mimeType := inlineData.Get("mime_type"); mimeType.Exists() {
|
||||
imageContent, _ = sjson.Set(imageContent, "source.media_type", mimeType.String())
|
||||
imageContent, _ = sjson.SetBytes(imageContent, "source.media_type", mimeType.String())
|
||||
}
|
||||
if data := inlineData.Get("data"); data.Exists() {
|
||||
imageContent, _ = sjson.Set(imageContent, "source.data", data.String())
|
||||
imageContent, _ = sjson.SetBytes(imageContent, "source.data", data.String())
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imageContent)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", imageContent)
|
||||
return true
|
||||
}
|
||||
|
||||
// File data conversion to text content with file info
|
||||
if fileData := part.Get("file_data"); fileData.Exists() {
|
||||
// For file data, we'll convert to text content with file info
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent := []byte(`{"type":"text","text":""}`)
|
||||
fileInfo := "File: " + fileData.Get("file_uri").String()
|
||||
if mimeType := fileData.Get("mime_type"); mimeType.Exists() {
|
||||
fileInfo += " (Type: " + mimeType.String() + ")"
|
||||
}
|
||||
textContent, _ = sjson.Set(textContent, "text", fileInfo)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
|
||||
textContent, _ = sjson.SetBytes(textContent, "text", fileInfo)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -336,8 +336,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Only add message if it has content
|
||||
if contentArray := gjson.Get(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
if contentArray := gjson.GetBytes(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -351,29 +351,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() {
|
||||
funcDecls.ForEach(func(_, funcDecl gjson.Result) bool {
|
||||
anthropicTool := `{"name":"","description":"","input_schema":{}}`
|
||||
anthropicTool := []byte(`{"name":"","description":"","input_schema":{}}`)
|
||||
|
||||
if name := funcDecl.Get("name"); name.Exists() {
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "name", name.String())
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", name.String())
|
||||
}
|
||||
if desc := funcDecl.Get("description"); desc.Exists() {
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "description", desc.String())
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", desc.String())
|
||||
}
|
||||
if params := funcDecl.Get("parameters"); params.Exists() {
|
||||
// Clean up the parameters schema for Claude Code compatibility
|
||||
cleaned := params.Raw
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
|
||||
cleaned := []byte(params.Raw)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
|
||||
} else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() {
|
||||
// Clean up the parameters schema for Claude Code compatibility
|
||||
cleaned := params.Raw
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
|
||||
cleaned := []byte(params.Raw)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
|
||||
}
|
||||
|
||||
anthropicTools = append(anthropicTools, gjson.Parse(anthropicTool).Value())
|
||||
anthropicTools = append(anthropicTools, gjson.ParseBytes(anthropicTool).Value())
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -381,7 +381,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
})
|
||||
|
||||
if len(anthropicTools) > 0 {
|
||||
out, _ = sjson.Set(out, "tools", anthropicTools)
|
||||
out, _ = sjson.SetBytes(out, "tools", anthropicTools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,27 +391,27 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if mode := funcCalling.Get("mode"); mode.Exists() {
|
||||
switch mode.String() {
|
||||
case "AUTO":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
|
||||
case "NONE":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"none"}`))
|
||||
case "ANY":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream setting configuration
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// Convert tool parameter types to lowercase for Claude Code compatibility
|
||||
var pathsToLower []string
|
||||
toolsResult := gjson.Get(out, "tools")
|
||||
toolsResult := gjson.GetBytes(out, "tools")
|
||||
util.Walk(toolsResult, "", "type", &pathsToLower)
|
||||
for _, p := range pathsToLower {
|
||||
fullPath := fmt.Sprintf("tools.%s", p)
|
||||
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
|
||||
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -30,7 +30,7 @@ type ConvertAnthropicResponseToGeminiParams struct {
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ResponseID string
|
||||
LastStorageOutput string
|
||||
LastStorageOutput []byte
|
||||
IsStreaming bool
|
||||
|
||||
// Streaming state for tool_use assembly
|
||||
@@ -52,8 +52,8 @@ type ConvertAnthropicResponseToGeminiParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
|
||||
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses
|
||||
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertAnthropicResponseToGeminiParams{
|
||||
Model: modelName,
|
||||
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
@@ -71,24 +71,24 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
eventType := root.Get("type").String()
|
||||
|
||||
// Base Gemini response template with default values
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
|
||||
|
||||
// Set model version
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" {
|
||||
// Map Claude model names back to Gemini model names
|
||||
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
|
||||
}
|
||||
|
||||
// Set response ID and creation time
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" {
|
||||
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
|
||||
}
|
||||
|
||||
// Set creation time to current time if not provided
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 {
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
@@ -97,7 +97,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String()
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String()
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "content_block_start":
|
||||
// Start of a content block - record tool_use name by index for functionCall assembly
|
||||
@@ -112,7 +112,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "content_block_delta":
|
||||
// Handle content delta (text, thinking, or tool use arguments)
|
||||
@@ -123,16 +123,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
case "text_delta":
|
||||
// Regular text content delta for normal response text
|
||||
if text := delta.Get("text"); text.Exists() && text.String() != "" {
|
||||
textPart := `{"text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", textPart)
|
||||
textPart := []byte(`{"text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", text.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", textPart)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Thinking/reasoning content delta for models with reasoning capabilities
|
||||
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
|
||||
thinkingPart := `{"thought":true,"text":""}`
|
||||
thinkingPart, _ = sjson.Set(thinkingPart, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", thinkingPart)
|
||||
thinkingPart := []byte(`{"thought":true,"text":""}`)
|
||||
thinkingPart, _ = sjson.SetBytes(thinkingPart, "text", text.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", thinkingPart)
|
||||
}
|
||||
case "input_json_delta":
|
||||
// Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop
|
||||
@@ -149,10 +149,10 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
if pj := delta.Get("partial_json"); pj.Exists() {
|
||||
b.WriteString(pj.String())
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
|
||||
case "content_block_stop":
|
||||
// End of content block - finalize tool calls if any
|
||||
@@ -170,16 +170,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
if name != "" || argsTrim != "" {
|
||||
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
|
||||
if name != "" {
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
|
||||
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", name)
|
||||
}
|
||||
if argsTrim != "" {
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsTrim)
|
||||
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsTrim))
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = template
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
|
||||
// cleanup used state for this index
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
|
||||
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx)
|
||||
@@ -187,9 +187,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
|
||||
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx)
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "message_delta":
|
||||
// Handle message-level changes (like stop reason and usage information)
|
||||
@@ -197,15 +197,15 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
|
||||
switch stopReason.String() {
|
||||
case "end_turn":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
case "tool_use":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
case "max_tokens":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "MAX_TOKENS")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "MAX_TOKENS")
|
||||
case "stop_sequence":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
default:
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,35 +216,35 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
|
||||
// Set basic usage metadata according to Gemini API specification
|
||||
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
|
||||
|
||||
// Add cache-related token counts if present (Claude Code API cache fields)
|
||||
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
|
||||
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
}
|
||||
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
|
||||
// Add cache read tokens to cached content count
|
||||
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
|
||||
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
|
||||
}
|
||||
|
||||
// Add thinking tokens if present (for models with reasoning capabilities)
|
||||
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
|
||||
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
|
||||
}
|
||||
|
||||
// Set traffic type (required by Gemini API)
|
||||
template, _ = sjson.Set(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
|
||||
}
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
case "message_stop":
|
||||
// Final message with usage information - no additional output needed
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
case "error":
|
||||
// Handle error responses and convert to Gemini error format
|
||||
errorMsg := root.Get("error.message").String()
|
||||
@@ -253,13 +253,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
// Create error response in Gemini format
|
||||
errorResponse := `{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`
|
||||
errorResponse, _ = sjson.Set(errorResponse, "error.message", errorMsg)
|
||||
return []string{errorResponse}
|
||||
errorResponse := []byte(`{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`)
|
||||
errorResponse, _ = sjson.SetBytes(errorResponse, "error.message", errorMsg)
|
||||
return [][]byte{errorResponse}
|
||||
|
||||
default:
|
||||
// Unknown event type, return empty response
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,13 +275,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
// Base Gemini response template for non-streaming with default values
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
|
||||
|
||||
// Set model version
|
||||
template, _ = sjson.Set(template, "modelVersion", modelName)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
|
||||
|
||||
streamingEvents := make([][]byte, 0)
|
||||
|
||||
@@ -304,15 +304,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
LastStorageOutput: "",
|
||||
LastStorageOutput: nil,
|
||||
IsStreaming: false,
|
||||
ToolUseNames: nil,
|
||||
ToolUseArgs: nil,
|
||||
}
|
||||
|
||||
// Process each streaming event and collect parts
|
||||
var allParts []string
|
||||
var finalUsageJSON string
|
||||
var allParts [][]byte
|
||||
var finalUsageJSON []byte
|
||||
var responseID string
|
||||
var createdAt int64
|
||||
|
||||
@@ -360,15 +360,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "text_delta":
|
||||
// Process regular text content
|
||||
if text := delta.Get("text"); text.Exists() && text.String() != "" {
|
||||
partJSON := `{"text":""}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", text.String())
|
||||
partJSON := []byte(`{"text":""}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
|
||||
allParts = append(allParts, partJSON)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Process reasoning/thinking content
|
||||
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
|
||||
partJSON := `{"thought":true,"text":""}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", text.String())
|
||||
partJSON := []byte(`{"thought":true,"text":""}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
|
||||
allParts = append(allParts, partJSON)
|
||||
}
|
||||
case "input_json_delta":
|
||||
@@ -402,12 +402,12 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
}
|
||||
}
|
||||
if name != "" || argsTrim != "" {
|
||||
functionCallJSON := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCallJSON := []byte(`{"functionCall":{"name":"","args":{}}}`)
|
||||
if name != "" {
|
||||
functionCallJSON, _ = sjson.Set(functionCallJSON, "functionCall.name", name)
|
||||
functionCallJSON, _ = sjson.SetBytes(functionCallJSON, "functionCall.name", name)
|
||||
}
|
||||
if argsTrim != "" {
|
||||
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
|
||||
functionCallJSON, _ = sjson.SetRawBytes(functionCallJSON, "functionCall.args", []byte(argsTrim))
|
||||
}
|
||||
allParts = append(allParts, functionCallJSON)
|
||||
// cleanup used state for this index
|
||||
@@ -422,35 +422,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "message_delta":
|
||||
// Extract final usage information using sjson for token counts and metadata
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
usageJSON := `{}`
|
||||
usageJSON := []byte(`{}`)
|
||||
|
||||
// Basic token counts for prompt and completion
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
|
||||
// Set basic usage metadata according to Gemini API specification
|
||||
usageJSON, _ = sjson.Set(usageJSON, "promptTokenCount", inputTokens)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "candidatesTokenCount", outputTokens)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "totalTokenCount", inputTokens+outputTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "promptTokenCount", inputTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "candidatesTokenCount", outputTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "totalTokenCount", inputTokens+outputTokens)
|
||||
|
||||
// Add cache-related token counts if present (Claude Code API cache fields)
|
||||
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
|
||||
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
}
|
||||
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
|
||||
// Add cache read tokens to cached content count
|
||||
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
|
||||
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", totalCacheTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", totalCacheTokens)
|
||||
}
|
||||
|
||||
// Add thinking tokens if present (for models with reasoning capabilities)
|
||||
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
|
||||
usageJSON, _ = sjson.Set(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
|
||||
}
|
||||
|
||||
// Set traffic type (required by Gemini API)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
|
||||
|
||||
finalUsageJSON = usageJSON
|
||||
}
|
||||
@@ -459,10 +459,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set response metadata
|
||||
if responseID != "" {
|
||||
template, _ = sjson.Set(template, "responseId", responseID)
|
||||
template, _ = sjson.SetBytes(template, "responseId", responseID)
|
||||
}
|
||||
if createdAt > 0 {
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
// Consolidate consecutive text parts and thinking parts for cleaner output
|
||||
@@ -470,35 +470,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set the consolidated parts array
|
||||
if len(consolidatedParts) > 0 {
|
||||
partsJSON := "[]"
|
||||
partsJSON := []byte(`[]`)
|
||||
for _, partJSON := range consolidatedParts {
|
||||
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
|
||||
partsJSON, _ = sjson.SetRawBytes(partsJSON, "-1", partJSON)
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts", partsJSON)
|
||||
}
|
||||
|
||||
// Set usage metadata
|
||||
if finalUsageJSON != "" {
|
||||
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
|
||||
if len(finalUsageJSON) > 0 {
|
||||
template, _ = sjson.SetRawBytes(template, "usageMetadata", finalUsageJSON)
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
// consolidateParts merges consecutive text parts and thinking parts to create a cleaner response.
|
||||
// This function processes the parts array to combine adjacent text elements and thinking elements
|
||||
// into single consolidated parts, which results in a more readable and efficient response structure.
|
||||
// Tool calls and other non-text parts are preserved as separate elements.
|
||||
func consolidateParts(parts []string) []string {
|
||||
func consolidateParts(parts [][]byte) [][]byte {
|
||||
if len(parts) == 0 {
|
||||
return parts
|
||||
}
|
||||
|
||||
var consolidated []string
|
||||
var consolidated [][]byte
|
||||
var currentTextPart strings.Builder
|
||||
var currentThoughtPart strings.Builder
|
||||
var hasText, hasThought bool
|
||||
@@ -506,8 +506,8 @@ func consolidateParts(parts []string) []string {
|
||||
flushText := func() {
|
||||
// Flush accumulated text content to the consolidated parts array
|
||||
if hasText && currentTextPart.Len() > 0 {
|
||||
textPartJSON := `{"text":""}`
|
||||
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
|
||||
textPartJSON := []byte(`{"text":""}`)
|
||||
textPartJSON, _ = sjson.SetBytes(textPartJSON, "text", currentTextPart.String())
|
||||
consolidated = append(consolidated, textPartJSON)
|
||||
currentTextPart.Reset()
|
||||
hasText = false
|
||||
@@ -517,8 +517,8 @@ func consolidateParts(parts []string) []string {
|
||||
flushThought := func() {
|
||||
// Flush accumulated thinking content to the consolidated parts array
|
||||
if hasThought && currentThoughtPart.Len() > 0 {
|
||||
thoughtPartJSON := `{"thought":true,"text":""}`
|
||||
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
|
||||
thoughtPartJSON := []byte(`{"thought":true,"text":""}`)
|
||||
thoughtPartJSON, _ = sjson.SetBytes(thoughtPartJSON, "text", currentThoughtPart.String())
|
||||
consolidated = append(consolidated, thoughtPartJSON)
|
||||
currentThoughtPart.Reset()
|
||||
hasThought = false
|
||||
@@ -526,7 +526,7 @@ func consolidateParts(parts []string) []string {
|
||||
}
|
||||
|
||||
for _, partJSON := range parts {
|
||||
part := gjson.Parse(partJSON)
|
||||
part := gjson.ParseBytes(partJSON)
|
||||
if !part.Exists() || !part.IsObject() {
|
||||
// Flush any pending parts and add this non-text part
|
||||
flushText()
|
||||
|
||||
@@ -61,7 +61,7 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
|
||||
|
||||
// Base Claude Code API template with default max_tokens value
|
||||
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
|
||||
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -79,20 +79,20 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if supportsAdaptive {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||
effort = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
|
||||
}
|
||||
} else {
|
||||
// Legacy/manual thinking (budget_tokens).
|
||||
@@ -100,13 +100,13 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,19 +128,19 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Model mapping to specify which Claude Code model to use
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Max tokens configuration with fallback to default value
|
||||
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := root.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
|
||||
} else if topP := root.Get("top_p"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
// Stop sequences configuration for custom termination conditions
|
||||
@@ -152,60 +152,53 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
return true
|
||||
})
|
||||
if len(stopSequences) > 0 {
|
||||
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
|
||||
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_sequences", []string{stop.String()})
|
||||
out, _ = sjson.SetBytes(out, "stop_sequences", []string{stop.String()})
|
||||
}
|
||||
}
|
||||
|
||||
// Stream configuration to enable or disable streaming responses
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// Process messages and transform them to Claude Code format
|
||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||
messageIndex := 0
|
||||
systemMessageIndex := -1
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
role := message.Get("role").String()
|
||||
contentResult := message.Get("content")
|
||||
|
||||
switch role {
|
||||
case "system":
|
||||
if systemMessageIndex == -1 {
|
||||
systemMsg := `{"role":"user","content":[]}`
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", systemMsg)
|
||||
systemMessageIndex = messageIndex
|
||||
messageIndex++
|
||||
}
|
||||
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", contentResult.String())
|
||||
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
|
||||
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
|
||||
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
case "user", "assistant":
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
|
||||
// Handle content based on its type (string or array)
|
||||
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
|
||||
part := `{"type":"text","text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentResult.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{"type":"text","text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", contentResult.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(claudePart))
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -221,9 +214,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
function := toolCall.Get("function")
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
|
||||
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "id", toolCallID)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "name", function.Get("name").String())
|
||||
|
||||
// Parse arguments for the tool call
|
||||
if args := function.Get("arguments"); args.Exists() {
|
||||
@@ -231,24 +224,24 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
|
||||
} else {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
|
||||
}
|
||||
} else {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
|
||||
}
|
||||
} else {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
|
||||
}
|
||||
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
messageIndex++
|
||||
|
||||
case "tool":
|
||||
@@ -256,19 +249,29 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
toolCallID := message.Get("tool_call_id").String()
|
||||
toolContentResult := message.Get("content")
|
||||
|
||||
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
||||
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
||||
msg := []byte(`{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "content.0.tool_use_id", toolCallID)
|
||||
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
|
||||
if toolResultContentRaw {
|
||||
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.0.content", []byte(toolResultContent))
|
||||
} else {
|
||||
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
|
||||
msg, _ = sjson.SetBytes(msg, "content.0.content", toolResultContent)
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
messageIndex++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Preserve a minimal conversational turn for system-only inputs.
|
||||
// Claude payloads with top-level system instructions but no messages are risky for downstream validation.
|
||||
if messageIndex == 0 {
|
||||
system := gjson.GetBytes(out, "system")
|
||||
if system.Exists() && system.IsArray() && len(system.Array()) > 0 {
|
||||
fallbackMsg := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", fallbackMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tools mapping: OpenAI tools -> Claude Code tools
|
||||
@@ -277,25 +280,25 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if tool.Get("type").String() == "function" {
|
||||
function := tool.Get("function")
|
||||
anthropicTool := `{"name":"","description":""}`
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
|
||||
anthropicTool := []byte(`{"name":"","description":""}`)
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", function.Get("name").String())
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", function.Get("description").String())
|
||||
|
||||
// Convert parameters schema for the tool
|
||||
if parameters := function.Get("parameters"); parameters.Exists() {
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
|
||||
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", anthropicTool)
|
||||
hasAnthropicTools = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasAnthropicTools {
|
||||
out, _ = sjson.Delete(out, "tools")
|
||||
out, _ = sjson.DeleteBytes(out, "tools")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,31 +311,31 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
case "none":
|
||||
// Don't set tool_choice, Claude Code will not use tools
|
||||
case "auto":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
|
||||
case "required":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
case gjson.JSON:
|
||||
// Specific tool choice mapping
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
functionName := toolChoice.Get("function.name").String()
|
||||
toolChoiceJSON := `{"type":"tool","name":""}`
|
||||
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
|
||||
toolChoiceJSON := []byte(`{"type":"tool","name":""}`)
|
||||
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", functionName)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
|
||||
switch part.Get("type").String() {
|
||||
case "text":
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
return textPart
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
|
||||
return string(textPart)
|
||||
|
||||
case "image_url":
|
||||
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
|
||||
@@ -345,10 +348,10 @@ func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
return docPart
|
||||
docPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
docPart, _ = sjson.SetBytes(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.SetBytes(docPart, "source.data", data)
|
||||
return string(docPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -373,15 +376,15 @@ func convertOpenAIImageURLToClaudePart(imageURL string) string {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
|
||||
return imagePart
|
||||
imagePart := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
imagePart, _ = sjson.SetBytes(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.SetBytes(imagePart, "source.data", parts[1])
|
||||
return string(imagePart)
|
||||
}
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
|
||||
return imagePart
|
||||
imagePart := []byte(`{"type":"image","source":{"type":"url","url":""}}`)
|
||||
imagePart, _ = sjson.SetBytes(imagePart, "source.url", imageURL)
|
||||
return string(imagePart)
|
||||
}
|
||||
|
||||
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||
@@ -394,28 +397,28 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
claudeContent := "[]"
|
||||
claudeContent := []byte("[]")
|
||||
partCount := 0
|
||||
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Type == gjson.String {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.String())
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", part.String())
|
||||
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", textPart)
|
||||
partCount++
|
||||
return true
|
||||
}
|
||||
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
|
||||
partCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if partCount > 0 || len(content.Array()) == 0 {
|
||||
return claudeContent, true
|
||||
return string(claudeContent), true
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
@@ -424,9 +427,9 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||
if content.IsObject() {
|
||||
claudePart := convertOpenAIContentPartToClaudePart(content)
|
||||
if claudePart != "" {
|
||||
claudeContent := "[]"
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||
return claudeContent, true
|
||||
claudeContent := []byte("[]")
|
||||
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
|
||||
return string(claudeContent), true
|
||||
}
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
@@ -135,3 +135,111 @@ func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
|
||||
t.Fatalf("Unexpected image URL: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_SystemRoleBecomesTopLevelSystem(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
system := resultJSON.Get("system")
|
||||
if !system.IsArray() {
|
||||
t.Fatalf("Expected top-level system array, got %s", system.Raw)
|
||||
}
|
||||
if len(system.Array()) != 1 {
|
||||
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system.Array()), system.Raw)
|
||||
}
|
||||
if got := system.Get("0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected system block type %q, got %q", "text", got)
|
||||
}
|
||||
if got := system.Get("0.text").String(); got != "You are a helpful assistant." {
|
||||
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
|
||||
}
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
if got := messages[0].Get("role").String(); got != "user" {
|
||||
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
|
||||
}
|
||||
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
|
||||
t.Fatalf("Expected user text %q, got %q", "Hello", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_MultipleSystemMessagesMergedIntoTopLevelSystem(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Rule 1"},
|
||||
{"role": "system", "content": [{"type": "text", "text": "Rule 2"}]},
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
system := resultJSON.Get("system").Array()
|
||||
if len(system) != 2 {
|
||||
t.Fatalf("Expected 2 system blocks, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
|
||||
}
|
||||
if got := system[0].Get("text").String(); got != "Rule 1" {
|
||||
t.Fatalf("Expected first system text %q, got %q", "Rule 1", got)
|
||||
}
|
||||
if got := system[1].Get("text").String(); got != "Rule 2" {
|
||||
t.Fatalf("Expected second system text %q, got %q", "Rule 2", got)
|
||||
}
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
if got := messages[0].Get("role").String(); got != "user" {
|
||||
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
|
||||
}
|
||||
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
|
||||
t.Fatalf("Expected user text %q, got %q", "Hello", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_SystemOnlyInputKeepsFallbackUserMessage(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
system := resultJSON.Get("system").Array()
|
||||
if len(system) != 1 {
|
||||
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
|
||||
}
|
||||
if got := system[0].Get("text").String(); got != "You are a helpful assistant." {
|
||||
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
|
||||
}
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("Expected 1 fallback message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
if got := messages[0].Get("role").String(); got != "user" {
|
||||
t.Fatalf("Expected fallback message role %q, got %q", "user", got)
|
||||
}
|
||||
if got := messages[0].Get("content.0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected fallback content type %q, got %q", "text", got)
|
||||
}
|
||||
if got := messages[0].Get("content.0.text").String(); got != "" {
|
||||
t.Fatalf("Expected fallback text %q, got %q", "", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,18 @@ type ToolCallAccumulator struct {
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
completionTokens = usage.Get("output_tokens").Int()
|
||||
cachedTokens = usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
|
||||
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
|
||||
totalTokens = promptTokens + completionTokens
|
||||
|
||||
return promptTokens, completionTokens, totalTokens, cachedTokens
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format.
|
||||
// This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses.
|
||||
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
|
||||
@@ -48,8 +60,8 @@ type ToolCallAccumulator struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
var localParam any
|
||||
if param == nil {
|
||||
param = &localParam
|
||||
@@ -63,7 +75,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
@@ -71,19 +83,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
eventType := root.Get("type").String()
|
||||
|
||||
// Base OpenAI streaming response template
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`)
|
||||
|
||||
// Set model
|
||||
if modelName != "" {
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
}
|
||||
|
||||
// Set response ID and creation time
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID != "" {
|
||||
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
}
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt > 0 {
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
@@ -93,19 +105,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID = message.Get("id").String()
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt = time.Now().Unix()
|
||||
|
||||
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
|
||||
// Set initial role to assistant for the response
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
|
||||
// Initialize tool calls accumulator for tracking tool call progress
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||
}
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
|
||||
case "content_block_start":
|
||||
// Start of a content block (text, tool use, or reasoning)
|
||||
@@ -128,10 +140,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
// Don't output anything yet - wait for complete tool call
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "content_block_delta":
|
||||
// Handle content delta (text, tool use arguments, or reasoning content)
|
||||
@@ -143,13 +155,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
case "text_delta":
|
||||
// Text content delta - send incremental text updates
|
||||
if text := delta.Get("text"); text.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", text.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", text.String())
|
||||
hasContent = true
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Accumulate reasoning/thinking content
|
||||
if thinking := delta.Get("thinking"); thinking.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", thinking.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", thinking.String())
|
||||
hasContent = true
|
||||
}
|
||||
case "input_json_delta":
|
||||
@@ -163,13 +175,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
// Don't output anything yet - wait for complete tool call
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
if hasContent {
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
} else {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
@@ -182,63 +194,60 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
if arguments == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.index", index)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.type", "function")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
|
||||
|
||||
// Clean up the accumulator for this index
|
||||
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "message_delta":
|
||||
// Handle message-level changes including stop reason and usage
|
||||
if delta := root.Get("delta"); delta.Exists() {
|
||||
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String())
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle usage information for token counts
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
|
||||
case "message_stop":
|
||||
// Final message event - no additional output needed
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "ping":
|
||||
// Ping events for keeping connection alive - no output needed
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "error":
|
||||
// Error event - format and return error response
|
||||
if errorData := root.Get("error"); errorData.Exists() {
|
||||
errorJSON := `{"error":{"message":"","type":""}}`
|
||||
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
|
||||
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
|
||||
return []string{errorJSON}
|
||||
errorJSON := []byte(`{"error":{"message":"","type":""}}`)
|
||||
errorJSON, _ = sjson.SetBytes(errorJSON, "error.message", errorData.Get("message").String())
|
||||
errorJSON, _ = sjson.SetBytes(errorJSON, "error.type", errorData.Get("type").String())
|
||||
return [][]byte{errorJSON}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
default:
|
||||
// Unknown event type - ignore
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,8 +279,8 @@ func mapAnthropicStopReasonToOpenAI(anthropicReason string) string {
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
chunks := make([][]byte, 0)
|
||||
|
||||
lines := bytes.Split(rawJSON, []byte("\n"))
|
||||
@@ -283,7 +292,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
|
||||
// Base OpenAI non-streaming response template
|
||||
out := `{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
out := []byte(`{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
|
||||
|
||||
var messageID string
|
||||
var model string
|
||||
@@ -366,32 +375,29 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set basic response fields including message ID, creation time, and model
|
||||
out, _ = sjson.Set(out, "id", messageID)
|
||||
out, _ = sjson.Set(out, "created", createdAt)
|
||||
out, _ = sjson.Set(out, "model", model)
|
||||
out, _ = sjson.SetBytes(out, "id", messageID)
|
||||
out, _ = sjson.SetBytes(out, "created", createdAt)
|
||||
out, _ = sjson.SetBytes(out, "model", model)
|
||||
|
||||
// Set message content by combining all text parts
|
||||
messageContent := strings.Join(contentParts, "")
|
||||
out, _ = sjson.Set(out, "choices.0.message.content", messageContent)
|
||||
out, _ = sjson.SetBytes(out, "choices.0.message.content", messageContent)
|
||||
|
||||
// Add reasoning content if available (following OpenAI reasoning format)
|
||||
if len(reasoningParts) > 0 {
|
||||
reasoningContent := strings.Join(reasoningParts, "")
|
||||
// Add reasoning as a separate field in the message
|
||||
out, _ = sjson.Set(out, "choices.0.message.reasoning", reasoningContent)
|
||||
out, _ = sjson.SetBytes(out, "choices.0.message.reasoning", reasoningContent)
|
||||
}
|
||||
|
||||
// Set tool calls if any were accumulated during processing
|
||||
@@ -417,19 +423,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
|
||||
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
|
||||
|
||||
out, _ = sjson.Set(out, idPath, accumulator.ID)
|
||||
out, _ = sjson.Set(out, typePath, "function")
|
||||
out, _ = sjson.Set(out, namePath, accumulator.Name)
|
||||
out, _ = sjson.Set(out, argumentsPath, arguments)
|
||||
out, _ = sjson.SetBytes(out, idPath, accumulator.ID)
|
||||
out, _ = sjson.SetBytes(out, typePath, "function")
|
||||
out, _ = sjson.SetBytes(out, namePath, accumulator.Name)
|
||||
out, _ = sjson.SetBytes(out, argumentsPath, arguments)
|
||||
toolCallsCount++
|
||||
}
|
||||
if toolCallsCount > 0 {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
|
||||
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", "tool_calls")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertClaudeResponseToOpenAI(
|
||||
ctx,
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
nil,
|
||||
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":13,"output_tokens":4,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}`),
|
||||
¶m,
|
||||
)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
|
||||
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
|
||||
|
||||
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
|
||||
|
||||
// Base Claude message payload
|
||||
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
|
||||
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -67,20 +67,20 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if supportsAdaptive {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||
effort = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
|
||||
}
|
||||
} else {
|
||||
// Legacy/manual thinking (budget_tokens).
|
||||
@@ -88,13 +88,13 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,15 +114,15 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
// Model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Max tokens
|
||||
if mot := root.Get("max_output_tokens"); mot.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", mot.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", mot.Int())
|
||||
}
|
||||
|
||||
// Stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// instructions -> as a leading message (use role user for Claude API compatibility)
|
||||
instructionsText := ""
|
||||
@@ -130,9 +130,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String {
|
||||
instructionsText = instr.String()
|
||||
if instructionsText != "" {
|
||||
sysMsg := `{"role":"user","content":""}`
|
||||
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
|
||||
sysMsg := []byte(`{"role":"user","content":""}`)
|
||||
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,9 +156,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
instructionsText = builder.String()
|
||||
if instructionsText != "" {
|
||||
sysMsg := `{"role":"user","content":""}`
|
||||
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
|
||||
sysMsg := []byte(`{"role":"user","content":""}`)
|
||||
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
|
||||
extractedFromSystem = true
|
||||
}
|
||||
}
|
||||
@@ -193,9 +193,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if t := part.Get("text"); t.Exists() {
|
||||
txt := t.String()
|
||||
textAggregate.WriteString(txt)
|
||||
contentPart := `{"type":"text","text":""}`
|
||||
contentPart, _ = sjson.Set(contentPart, "text", txt)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
contentPart := []byte(`{"type":"text","text":""}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "text", txt)
|
||||
partsJSON = append(partsJSON, string(contentPart))
|
||||
}
|
||||
if ptype == "input_text" {
|
||||
role = "user"
|
||||
@@ -208,7 +208,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
url = part.Get("url").String()
|
||||
}
|
||||
if url != "" {
|
||||
var contentPart string
|
||||
var contentPart []byte
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
trimmed := strings.TrimPrefix(url, "data:")
|
||||
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||
@@ -221,16 +221,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
if data != "" {
|
||||
contentPart = `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
contentPart = []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
|
||||
}
|
||||
} else {
|
||||
contentPart = `{"type":"image","source":{"type":"url","url":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.url", url)
|
||||
contentPart = []byte(`{"type":"image","source":{"type":"url","url":""}}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.url", url)
|
||||
}
|
||||
if contentPart != "" {
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
if len(contentPart) > 0 {
|
||||
partsJSON = append(partsJSON, string(contentPart))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
@@ -252,10 +252,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
}
|
||||
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
contentPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, string(contentPart))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
@@ -280,24 +280,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
if len(partsJSON) > 0 {
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||
// Preserve legacy behavior for single text content
|
||||
msg, _ = sjson.Delete(msg, "content")
|
||||
msg, _ = sjson.DeleteBytes(msg, "content")
|
||||
textPart := gjson.Parse(partsJSON[0])
|
||||
msg, _ = sjson.Set(msg, "content", textPart.Get("text").String())
|
||||
msg, _ = sjson.SetBytes(msg, "content", textPart.Get("text").String())
|
||||
} else {
|
||||
for _, partJSON := range partsJSON {
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", partJSON)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(partJSON))
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
} else if textAggregate.Len() > 0 || role == "system" {
|
||||
msg := `{"role":"","content":""}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg, _ = sjson.Set(msg, "content", textAggregate.String())
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
msg := []byte(`{"role":"","content":""}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
msg, _ = sjson.SetBytes(msg, "content", textAggregate.String())
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
}
|
||||
|
||||
case "function_call":
|
||||
@@ -309,31 +309,31 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
name := item.Get("name").String()
|
||||
argsStr := item.Get("arguments").String()
|
||||
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", callID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", name)
|
||||
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "id", callID)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "name", name)
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
|
||||
}
|
||||
}
|
||||
|
||||
asst := `{"role":"assistant","content":[]}`
|
||||
asst, _ = sjson.SetRaw(asst, "content.-1", toolUse)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", asst)
|
||||
asst := []byte(`{"role":"assistant","content":[]}`)
|
||||
asst, _ = sjson.SetRawBytes(asst, "content.-1", toolUse)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
|
||||
|
||||
case "function_call_output":
|
||||
// Map to user tool_result
|
||||
callID := item.Get("call_id").String()
|
||||
outputStr := item.Get("output").String()
|
||||
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
|
||||
toolResult, _ = sjson.Set(toolResult, "tool_use_id", callID)
|
||||
toolResult, _ = sjson.Set(toolResult, "content", outputStr)
|
||||
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", callID)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "content", outputStr)
|
||||
|
||||
usr := `{"role":"user","content":[]}`
|
||||
usr, _ = sjson.SetRaw(usr, "content.-1", toolResult)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", usr)
|
||||
usr := []byte(`{"role":"user","content":[]}`)
|
||||
usr, _ = sjson.SetRawBytes(usr, "content.-1", toolResult)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", usr)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -341,27 +341,27 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
// tools mapping: parameters -> input_schema
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
toolsJSON := "[]"
|
||||
toolsJSON := []byte("[]")
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
tJSON := `{"name":"","description":"","input_schema":{}}`
|
||||
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
|
||||
if n := tool.Get("name"); n.Exists() {
|
||||
tJSON, _ = sjson.Set(tJSON, "name", n.String())
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
|
||||
}
|
||||
if d := tool.Get("description"); d.Exists() {
|
||||
tJSON, _ = sjson.Set(tJSON, "description", d.String())
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
|
||||
}
|
||||
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
|
||||
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
|
||||
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
|
||||
}
|
||||
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", tJSON)
|
||||
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
|
||||
return true
|
||||
})
|
||||
if gjson.Parse(toolsJSON).IsArray() && len(gjson.Parse(toolsJSON).Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "tools", toolsJSON)
|
||||
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "tools", toolsJSON)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -371,23 +371,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
case gjson.String:
|
||||
switch toolChoice.String() {
|
||||
case "auto":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
|
||||
case "none":
|
||||
// Leave unset; implies no tools
|
||||
case "required":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
case gjson.JSON:
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
fn := toolChoice.Get("function.name").String()
|
||||
toolChoiceJSON := `{"name":"","type":"tool"}`
|
||||
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
|
||||
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
|
||||
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -50,12 +51,12 @@ func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitEvent(event string, payload string) string {
|
||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||
func emitEvent(event string, payload []byte) []byte {
|
||||
return translatorcommon.SSEEventData(event, payload)
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
|
||||
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)}
|
||||
}
|
||||
@@ -63,12 +64,12 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
|
||||
// Expect `data: {..}` from Claude clients
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
ev := root.Get("type").String()
|
||||
var out []string
|
||||
var out [][]byte
|
||||
|
||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||
|
||||
@@ -105,16 +106,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
}
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.SetBytes(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.SetBytes(created, "response.created_at", st.CreatedAt)
|
||||
out = append(out, emitEvent("response.created", created))
|
||||
// response.in_progress
|
||||
inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
|
||||
inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
|
||||
inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
|
||||
inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt)
|
||||
inprog := []byte(`{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`)
|
||||
inprog, _ = sjson.SetBytes(inprog, "sequence_number", nextSeq())
|
||||
inprog, _ = sjson.SetBytes(inprog, "response.id", st.ResponseID)
|
||||
inprog, _ = sjson.SetBytes(inprog, "response.created_at", st.CreatedAt)
|
||||
out = append(out, emitEvent("response.in_progress", inprog))
|
||||
}
|
||||
case "content_block_start":
|
||||
@@ -128,25 +129,25 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
// open message item + content part
|
||||
st.InTextBlock = true
|
||||
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "item.id", st.CurrentMsgID)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "item.id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
|
||||
part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
part, _ = sjson.Set(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.Set(part, "item_id", st.CurrentMsgID)
|
||||
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.SetBytes(part, "item_id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.content_part.added", part))
|
||||
} else if typ == "tool_use" {
|
||||
st.InFuncBlock = true
|
||||
st.CurrentFCID = cb.Get("id").String()
|
||||
name := cb.Get("name").String()
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "output_index", idx)
|
||||
item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
item, _ = sjson.Set(item, "item.call_id", st.CurrentFCID)
|
||||
item, _ = sjson.Set(item, "item.name", name)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
item, _ = sjson.SetBytes(item, "item.call_id", st.CurrentFCID)
|
||||
item, _ = sjson.SetBytes(item, "item.name", name)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
if st.FuncArgsBuf[idx] == nil {
|
||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||
@@ -160,16 +161,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
st.ReasoningIndex = idx
|
||||
st.ReasoningBuf.Reset()
|
||||
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "output_index", idx)
|
||||
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningItemID)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
// add a summary part placeholder
|
||||
part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||
part, _ = sjson.Set(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.Set(part, "item_id", st.ReasoningItemID)
|
||||
part, _ = sjson.Set(part, "output_index", idx)
|
||||
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.SetBytes(part, "item_id", st.ReasoningItemID)
|
||||
part, _ = sjson.SetBytes(part, "output_index", idx)
|
||||
out = append(out, emitEvent("response.reasoning_summary_part.added", part))
|
||||
st.ReasoningPartAdded = true
|
||||
}
|
||||
@@ -181,10 +182,10 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
dt := d.Get("type").String()
|
||||
if dt == "text_delta" {
|
||||
if t := d.Get("text"); t.Exists() {
|
||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", st.CurrentMsgID)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.output_text.delta", msg))
|
||||
// aggregate text for response.output
|
||||
st.TextBuf.WriteString(t.String())
|
||||
@@ -196,22 +197,22 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||
}
|
||||
st.FuncArgsBuf[idx].WriteString(pj.String())
|
||||
msg := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
msg, _ = sjson.Set(msg, "output_index", idx)
|
||||
msg, _ = sjson.Set(msg, "delta", pj.String())
|
||||
msg := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", idx)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", pj.String())
|
||||
out = append(out, emitEvent("response.function_call_arguments.delta", msg))
|
||||
}
|
||||
} else if dt == "thinking_delta" {
|
||||
if st.ReasoningActive {
|
||||
if t := d.Get("thinking"); t.Exists() {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
msg := []byte(`{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
}
|
||||
@@ -219,17 +220,17 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
case "content_block_stop":
|
||||
idx := int(root.Get("index").Int())
|
||||
if st.InTextBlock {
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
|
||||
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
st.InTextBlock = false
|
||||
} else if st.InFuncBlock {
|
||||
@@ -239,34 +240,34 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
args = buf.String()
|
||||
}
|
||||
}
|
||||
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
|
||||
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
|
||||
fcDone, _ = sjson.Set(fcDone, "arguments", args)
|
||||
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", idx)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
|
||||
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
|
||||
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
|
||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", st.CurrentFCID)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
st.InFuncBlock = false
|
||||
} else if st.ReasoningActive {
|
||||
full := st.ReasoningBuf.String()
|
||||
textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
|
||||
textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID)
|
||||
textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
|
||||
textDone, _ = sjson.Set(textDone, "text", full)
|
||||
textDone := []byte(`{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`)
|
||||
textDone, _ = sjson.SetBytes(textDone, "sequence_number", nextSeq())
|
||||
textDone, _ = sjson.SetBytes(textDone, "item_id", st.ReasoningItemID)
|
||||
textDone, _ = sjson.SetBytes(textDone, "output_index", st.ReasoningIndex)
|
||||
textDone, _ = sjson.SetBytes(textDone, "text", full)
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
|
||||
partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", full)
|
||||
partDone := []byte(`{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", st.ReasoningItemID)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", st.ReasoningIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", full)
|
||||
out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
|
||||
st.ReasoningActive = false
|
||||
st.ReasoningPartAdded = false
|
||||
@@ -284,92 +285,92 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
case "message_stop":
|
||||
|
||||
completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
|
||||
completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
|
||||
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
|
||||
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.SetBytes(completed, "response.created_at", st.CreatedAt)
|
||||
// Inject original request fields into response as per docs/response.completed.json
|
||||
|
||||
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
|
||||
if len(reqBytes) > 0 {
|
||||
req := gjson.ParseBytes(reqBytes)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
|
||||
}
|
||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
|
||||
}
|
||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
|
||||
}
|
||||
if v := req.Get("model"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.model", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
|
||||
}
|
||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
|
||||
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
|
||||
}
|
||||
if v := req.Get("previous_response_id"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
|
||||
}
|
||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
|
||||
}
|
||||
if v := req.Get("reasoning"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
|
||||
}
|
||||
if v := req.Get("safety_identifier"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
|
||||
}
|
||||
if v := req.Get("service_tier"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.service_tier", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
|
||||
}
|
||||
if v := req.Get("store"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.store", v.Bool())
|
||||
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
|
||||
}
|
||||
if v := req.Get("temperature"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.temperature", v.Float())
|
||||
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
|
||||
}
|
||||
if v := req.Get("text"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.text", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
|
||||
}
|
||||
if v := req.Get("tool_choice"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
|
||||
}
|
||||
if v := req.Get("tools"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.tools", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
|
||||
}
|
||||
if v := req.Get("top_logprobs"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
|
||||
}
|
||||
if v := req.Get("top_p"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.top_p", v.Float())
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
|
||||
}
|
||||
if v := req.Get("truncation"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.truncation", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
|
||||
}
|
||||
if v := req.Get("user"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.user", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
|
||||
}
|
||||
if v := req.Get("metadata"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.metadata", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
// Build response.output from aggregated state
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
// reasoning item (if any)
|
||||
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
// assistant message item (if any text)
|
||||
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
// function_call items (in ascending index order for determinism)
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
@@ -396,16 +397,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if callID == "" && st.CurrentFCID != "" {
|
||||
callID = st.CurrentFCID
|
||||
}
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", callID)
|
||||
item, _ = sjson.Set(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
}
|
||||
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
|
||||
reasoningTokens := int64(0)
|
||||
@@ -414,15 +415,15 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
usagePresent := st.UsageSeen || reasoningTokens > 0
|
||||
if usagePresent {
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.InputTokens)
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.OutputTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.InputTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.OutputTokens)
|
||||
if reasoningTokens > 0 {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
total := st.InputTokens + st.OutputTokens
|
||||
if total > 0 || st.UsageSeen {
|
||||
completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
|
||||
}
|
||||
}
|
||||
out = append(out, emitEvent("response.completed", completed))
|
||||
@@ -432,7 +433,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAIResponsesNonStream aggregates Claude SSE into a single OpenAI Responses JSON.
|
||||
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
// Aggregate Claude SSE lines into a single OpenAI Responses JSON (non-stream)
|
||||
// We follow the same aggregation logic as the streaming variant but produce
|
||||
// one final object matching docs/out.json structure.
|
||||
@@ -455,7 +456,7 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
}
|
||||
|
||||
// Base OpenAI Responses (non-stream) object
|
||||
out := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`
|
||||
out := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`)
|
||||
|
||||
// Aggregation state
|
||||
var (
|
||||
@@ -557,88 +558,88 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
}
|
||||
|
||||
// Populate base fields
|
||||
out, _ = sjson.Set(out, "id", responseID)
|
||||
out, _ = sjson.Set(out, "created_at", createdAt)
|
||||
out, _ = sjson.SetBytes(out, "id", responseID)
|
||||
out, _ = sjson.SetBytes(out, "created_at", createdAt)
|
||||
|
||||
// Inject request echo fields as top-level (similar to streaming variant)
|
||||
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
|
||||
if len(reqBytes) > 0 {
|
||||
req := gjson.ParseBytes(reqBytes)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "instructions", v.String())
|
||||
out, _ = sjson.SetBytes(out, "instructions", v.String())
|
||||
}
|
||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "max_output_tokens", v.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_output_tokens", v.Int())
|
||||
}
|
||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tool_calls", v.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tool_calls", v.Int())
|
||||
}
|
||||
if v := req.Get("model"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "model", v.String())
|
||||
out, _ = sjson.SetBytes(out, "model", v.String())
|
||||
}
|
||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", v.Bool())
|
||||
out, _ = sjson.SetBytes(out, "parallel_tool_calls", v.Bool())
|
||||
}
|
||||
if v := req.Get("previous_response_id"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "previous_response_id", v.String())
|
||||
out, _ = sjson.SetBytes(out, "previous_response_id", v.String())
|
||||
}
|
||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "prompt_cache_key", v.String())
|
||||
out, _ = sjson.SetBytes(out, "prompt_cache_key", v.String())
|
||||
}
|
||||
if v := req.Get("reasoning"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "reasoning", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "reasoning", v.Value())
|
||||
}
|
||||
if v := req.Get("safety_identifier"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "safety_identifier", v.String())
|
||||
out, _ = sjson.SetBytes(out, "safety_identifier", v.String())
|
||||
}
|
||||
if v := req.Get("service_tier"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "service_tier", v.String())
|
||||
out, _ = sjson.SetBytes(out, "service_tier", v.String())
|
||||
}
|
||||
if v := req.Get("store"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "store", v.Bool())
|
||||
out, _ = sjson.SetBytes(out, "store", v.Bool())
|
||||
}
|
||||
if v := req.Get("temperature"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", v.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", v.Float())
|
||||
}
|
||||
if v := req.Get("text"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text", v.Value())
|
||||
}
|
||||
if v := req.Get("tool_choice"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "tool_choice", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "tool_choice", v.Value())
|
||||
}
|
||||
if v := req.Get("tools"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "tools", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "tools", v.Value())
|
||||
}
|
||||
if v := req.Get("top_logprobs"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "top_logprobs", v.Int())
|
||||
out, _ = sjson.SetBytes(out, "top_logprobs", v.Int())
|
||||
}
|
||||
if v := req.Get("top_p"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", v.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", v.Float())
|
||||
}
|
||||
if v := req.Get("truncation"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "truncation", v.String())
|
||||
out, _ = sjson.SetBytes(out, "truncation", v.String())
|
||||
}
|
||||
if v := req.Get("user"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "user", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "user", v.Value())
|
||||
}
|
||||
if v := req.Get("metadata"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "metadata", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "metadata", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
// Build output array
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
if reasoningBuf.Len() > 0 {
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", reasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", reasoningItemID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", reasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if currentMsgID != "" || textBuf.Len() > 0 {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", currentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", textBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", currentMsgID)
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", textBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
// Preserve index order
|
||||
@@ -659,28 +660,28 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", st.id)
|
||||
item, _ = sjson.Set(item, "name", st.name)
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", st.id))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", st.id)
|
||||
item, _ = sjson.SetBytes(item, "name", st.name)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
}
|
||||
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
|
||||
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
|
||||
// Usage
|
||||
total := inputTokens + outputTokens
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", total)
|
||||
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", total)
|
||||
if reasoningBuf.Len() > 0 {
|
||||
// Rough estimate similar to chat completions
|
||||
reasoningTokens := int64(len(reasoningBuf.String()) / 4)
|
||||
if reasoningTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,15 +36,15 @@ import (
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
template := []byte(`{"model":"","instructions":"","input":[]}`)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
|
||||
// Process system messages and convert them to input content format.
|
||||
systemsResult := rootResult.Get("system")
|
||||
if systemsResult.Exists() {
|
||||
message := `{"type":"message","role":"developer","content":[]}`
|
||||
message := []byte(`{"type":"message","role":"developer","content":[]}`)
|
||||
contentIndex := 0
|
||||
|
||||
appendSystemText := func(text string) {
|
||||
@@ -52,8 +52,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
return
|
||||
}
|
||||
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
if contentIndex > 0 {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", message)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,9 +83,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
messageResult := messageResults[i]
|
||||
messageRole := messageResult.Get("role").String()
|
||||
|
||||
newMessage := func() string {
|
||||
msg := `{"type": "message","role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", messageRole)
|
||||
newMessage := func() []byte {
|
||||
msg := []byte(`{"type":"message","role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", messageRole)
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
flushMessage := func() {
|
||||
if hasContent {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", message)
|
||||
message = newMessage()
|
||||
contentIndex = 0
|
||||
hasContent = false
|
||||
@@ -107,15 +107,15 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if messageRole == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
hasContent = true
|
||||
}
|
||||
|
||||
appendImageContent := func(dataURL string) {
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
|
||||
contentIndex++
|
||||
hasContent = true
|
||||
}
|
||||
@@ -151,8 +151,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
case "tool_use":
|
||||
flushMessage()
|
||||
functionCallMessage := `{"type":"function_call"}`
|
||||
functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
|
||||
functionCallMessage := []byte(`{"type":"function_call"}`)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
|
||||
{
|
||||
name := messageContentResult.Get("name").String()
|
||||
toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON)
|
||||
@@ -161,19 +161,19 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
functionCallMessage, _ = sjson.Set(functionCallMessage, "name", name)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "name", name)
|
||||
}
|
||||
functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
|
||||
template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallMessage)
|
||||
case "tool_result":
|
||||
flushMessage()
|
||||
functionCallOutputMessage := `{"type":"function_call_output"}`
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||
functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||
|
||||
contentResult := messageContentResult.Get("content")
|
||||
if contentResult.IsArray() {
|
||||
toolResultContentIndex := 0
|
||||
toolResultContent := `[]`
|
||||
toolResultContent := []byte(`[]`)
|
||||
contentResults := contentResult.Array()
|
||||
for k := 0; k < len(contentResults); k++ {
|
||||
toolResultContentType := contentResults[k].Get("type").String()
|
||||
@@ -194,27 +194,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
|
||||
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
|
||||
toolResultContentIndex++
|
||||
}
|
||||
}
|
||||
} else if toolResultContentType == "text" {
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
|
||||
toolResultContentIndex++
|
||||
}
|
||||
}
|
||||
if toolResultContent != `[]` {
|
||||
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
|
||||
if toolResultContentIndex > 0 {
|
||||
functionCallOutputMessage, _ = sjson.SetRawBytes(functionCallOutputMessage, "output", toolResultContent)
|
||||
} else {
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
}
|
||||
} else {
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
}
|
||||
|
||||
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallOutputMessage)
|
||||
}
|
||||
}
|
||||
flushMessage()
|
||||
@@ -229,8 +229,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Convert tools declarations to the expected format for the Codex API.
|
||||
toolsResult := rootResult.Get("tools")
|
||||
if toolsResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "tools", `[]`)
|
||||
template, _ = sjson.Set(template, "tool_choice", `auto`)
|
||||
template, _ = sjson.SetRawBytes(template, "tools", []byte(`[]`))
|
||||
template, _ = sjson.SetBytes(template, "tool_choice", `auto`)
|
||||
toolResults := toolsResult.Array()
|
||||
// Build short name map from declared tools
|
||||
var names []string
|
||||
@@ -246,11 +246,11 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Special handling: map Claude web search tool to Codex web_search
|
||||
if toolResult.Get("type").String() == "web_search_20250305" {
|
||||
// Replace the tool content entirely with {"type":"web_search"}
|
||||
template, _ = sjson.SetRaw(template, "tools.-1", `{"type":"web_search"}`)
|
||||
template, _ = sjson.SetRawBytes(template, "tools.-1", []byte(`{"type":"web_search"}`))
|
||||
continue
|
||||
}
|
||||
tool := toolResult.Raw
|
||||
tool, _ = sjson.Set(tool, "type", "function")
|
||||
tool := []byte(toolResult.Raw)
|
||||
tool, _ = sjson.SetBytes(tool, "type", "function")
|
||||
// Apply shortened name if needed
|
||||
if v := toolResult.Get("name"); v.Exists() {
|
||||
name := v.String()
|
||||
@@ -259,20 +259,26 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
tool, _ = sjson.Set(tool, "name", name)
|
||||
tool, _ = sjson.SetBytes(tool, "name", name)
|
||||
}
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
|
||||
tool, _ = sjson.Delete(tool, "input_schema")
|
||||
tool, _ = sjson.Delete(tool, "parameters.$schema")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
tool, _ = sjson.Set(tool, "strict", false)
|
||||
template, _ = sjson.SetRaw(template, "tools.-1", tool)
|
||||
tool, _ = sjson.SetRawBytes(tool, "parameters", []byte(normalizeToolParameters(toolResult.Get("input_schema").Raw)))
|
||||
tool, _ = sjson.DeleteBytes(tool, "input_schema")
|
||||
tool, _ = sjson.DeleteBytes(tool, "parameters.$schema")
|
||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||
tool, _ = sjson.SetBytes(tool, "strict", false)
|
||||
template, _ = sjson.SetRawBytes(template, "tools.-1", tool)
|
||||
}
|
||||
}
|
||||
|
||||
// Default to parallel tool calls unless tool_choice explicitly disables them.
|
||||
parallelToolCalls := true
|
||||
if disableParallelToolUse := rootResult.Get("tool_choice.disable_parallel_tool_use"); disableParallelToolUse.Exists() {
|
||||
parallelToolCalls = !disableParallelToolUse.Bool()
|
||||
}
|
||||
|
||||
// Add additional configuration parameters for the Codex API.
|
||||
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||
template, _ = sjson.SetBytes(template, "parallel_tool_calls", parallelToolCalls)
|
||||
|
||||
// Convert thinking.budget_tokens to reasoning.effort.
|
||||
reasoningEffort := "medium"
|
||||
@@ -303,13 +309,13 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.Set(template, "stream", true)
|
||||
template, _ = sjson.Set(template, "store", false)
|
||||
template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})
|
||||
template, _ = sjson.SetBytes(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.SetBytes(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.SetBytes(template, "stream", true)
|
||||
template, _ = sjson.SetBytes(template, "store", false)
|
||||
template, _ = sjson.SetBytes(template, "include", []string{"reasoning.encrypted_content"})
|
||||
|
||||
return []byte(template)
|
||||
return template
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies a simple shortening rule for a single name.
|
||||
@@ -412,15 +418,15 @@ func normalizeToolParameters(raw string) string {
|
||||
if raw == "" || raw == "null" || !gjson.Valid(raw) {
|
||||
return `{"type":"object","properties":{}}`
|
||||
}
|
||||
schema := raw
|
||||
result := gjson.Parse(raw)
|
||||
schema := []byte(raw)
|
||||
schemaType := result.Get("type").String()
|
||||
if schemaType == "" {
|
||||
schema, _ = sjson.Set(schema, "type", "object")
|
||||
schema, _ = sjson.SetBytes(schema, "type", "object")
|
||||
schemaType = "object"
|
||||
}
|
||||
if schemaType == "object" && !result.Get("properties").Exists() {
|
||||
schema, _ = sjson.SetRaw(schema, "properties", `{}`)
|
||||
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
|
||||
}
|
||||
return schema
|
||||
return string(schema)
|
||||
}
|
||||
|
||||
@@ -87,3 +87,49 @@ func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantParallelToolCalls bool
|
||||
}{
|
||||
{
|
||||
name: "Default to true when tool_choice.disable_parallel_tool_use is absent",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantParallelToolCalls: true,
|
||||
},
|
||||
{
|
||||
name: "Disable parallel tool calls when client opts out",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"tool_choice": {"disable_parallel_tool_use": true},
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantParallelToolCalls: false,
|
||||
},
|
||||
{
|
||||
name: "Keep parallel tool calls enabled when client explicitly allows them",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"tool_choice": {"disable_parallel_tool_use": false},
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantParallelToolCalls: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
if got := resultJSON.Get("parallel_tool_calls").Bool(); got != tt.wantParallelToolCalls {
|
||||
t.Fatalf("parallel_tool_calls = %v, want %v. Output: %s", got, tt.wantParallelToolCalls, string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,9 +9,9 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -43,8 +43,8 @@ type ConvertCodexResponseToClaudeParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Claude Code-compatible JSON responses
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCodexResponseToClaudeParams{
|
||||
HasToolCall: false,
|
||||
@@ -54,95 +54,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
|
||||
// log.Debugf("rawJSON: %s", string(rawJSON))
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
output := ""
|
||||
output := make([]byte, 0, 512)
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
typeResult := rootResult.Get("type")
|
||||
typeStr := typeResult.String()
|
||||
template := ""
|
||||
var template []byte
|
||||
if typeStr == "response.created" {
|
||||
template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
|
||||
template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())
|
||||
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
|
||||
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.SetBytes(template, "message.id", rootResult.Get("response.id").String())
|
||||
|
||||
output = "event: message_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_part.added" {
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_text.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
|
||||
|
||||
output = "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_part.done" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
|
||||
} else if typeStr == "response.content_part.added" {
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
} else if typeStr == "response.output_text.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
|
||||
|
||||
output = "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.content_part.done" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
} else if typeStr == "response.completed" {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
stopReason := rootResult.Get("response.stop_reason").String()
|
||||
if p {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
|
||||
} else if stopReason == "max_tokens" || stopReason == "stop" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", stopReason)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", "end_turn")
|
||||
}
|
||||
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.input_tokens", inputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
output = "event: message_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output += "event: message_stop\n"
|
||||
output += `data: {"type":"message_stop"}`
|
||||
output += "\n\n"
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_delta", template, 2)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_stop", []byte(`{"type":"message_stop"}`), 2)
|
||||
} else if typeStr == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -150,37 +140,33 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
template, _ = sjson.Set(template, "content_block.name", name)
|
||||
template, _ = sjson.SetBytes(template, "content_block.name", name)
|
||||
}
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
}
|
||||
} else if typeStr == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
}
|
||||
} else if typeStr == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.function_call_arguments.done" {
|
||||
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||
// in a single "done" event without preceding "delta" events.
|
||||
@@ -189,17 +175,16 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
// When delta events were already received, skip to avoid duplicating arguments.
|
||||
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||
if args := rootResult.Get("arguments").String(); args != "" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", args)
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
return [][]byte{output}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response.
|
||||
@@ -214,28 +199,28 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Claude Code-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
|
||||
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
responseData := rootResult.Get("response")
|
||||
if !responseData.Exists() {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
||||
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
out, _ = sjson.SetBytes(out, "id", responseData.Get("id").String())
|
||||
out, _ = sjson.SetBytes(out, "model", responseData.Get("model").String())
|
||||
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
hasToolCall := false
|
||||
@@ -276,9 +261,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
}
|
||||
if thinkingBuilder.Len() > 0 {
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
case "message":
|
||||
if content := item.Get("content"); content.Exists() {
|
||||
@@ -287,9 +272,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
if part.Get("type").String() == "output_text" {
|
||||
text := part.Get("text").String()
|
||||
if text != "" {
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", text)
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -297,9 +282,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
} else {
|
||||
text := content.String()
|
||||
if text != "" {
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", text)
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -310,9 +295,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
name = original
|
||||
}
|
||||
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
@@ -320,23 +305,23 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
inputRaw = argsJSON.Raw
|
||||
}
|
||||
}
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
|
||||
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
|
||||
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", stopReason.String())
|
||||
} else if hasToolCall {
|
||||
out, _ = sjson.Set(out, "stop_reason", "tool_use")
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", "tool_use")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", "end_turn")
|
||||
}
|
||||
|
||||
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
|
||||
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "stop_sequence", []byte(stopSequence.Raw))
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -386,6 +371,6 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
|
||||
return rev
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||
}
|
||||
|
||||
@@ -6,10 +6,9 @@ package geminiCLI
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/sjson"
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
)
|
||||
|
||||
// ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format.
|
||||
@@ -24,14 +23,12 @@ import (
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
newOutputs := make([]string, 0)
|
||||
newOutputs := make([][]byte, 0, len(outputs))
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
json := `{"response": {}}`
|
||||
output, _ := sjson.SetRaw(json, "response", outputs[i])
|
||||
newOutputs = append(newOutputs, output)
|
||||
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
|
||||
}
|
||||
return newOutputs
|
||||
}
|
||||
@@ -47,15 +44,12 @@ func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, orig
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
// log.Debug(string(rawJSON))
|
||||
strJSON := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
json := `{"response": {}}`
|
||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
||||
return strJSON
|
||||
// - []byte: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
out := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
return translatorcommon.WrapGeminiCLIResponse(out)
|
||||
}
|
||||
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ import (
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
out := []byte(`{"model":"","instructions":"","input":[]}`)
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -82,24 +82,24 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
// Model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// System instruction -> as a user message with input_text parts
|
||||
sysParts := root.Get("system_instruction.parts")
|
||||
if sysParts.IsArray() {
|
||||
msg := `{"type":"message","role":"developer","content":[]}`
|
||||
msg := []byte(`{"type":"message","role":"developer","content":[]}`)
|
||||
arr := sysParts.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
p := arr[i]
|
||||
if t := p.Get("text"); t.Exists() {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_text")
|
||||
part, _ = sjson.Set(part, "text", t.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_text")
|
||||
part, _ = sjson.SetBytes(part, "text", t.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
if len(gjson.Get(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
if len(gjson.GetBytes(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,23 +123,23 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
p := parr[j]
|
||||
// text part
|
||||
if t := p.Get("text"); t.Exists() {
|
||||
msg := `{"type":"message","role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"type":"message","role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
partType := "input_text"
|
||||
if role == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", partType)
|
||||
part, _ = sjson.Set(part, "text", t.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", partType)
|
||||
part, _ = sjson.SetBytes(part, "text", t.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
|
||||
continue
|
||||
}
|
||||
|
||||
// function call from model
|
||||
if fc := p.Get("functionCall"); fc.Exists() {
|
||||
fn := `{"type":"function_call"}`
|
||||
fn := []byte(`{"type":"function_call"}`)
|
||||
if name := fc.Get("name"); name.Exists() {
|
||||
n := name.String()
|
||||
if short, ok := shortMap[n]; ok {
|
||||
@@ -147,31 +147,31 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
n = shortenNameIfNeeded(n)
|
||||
}
|
||||
fn, _ = sjson.Set(fn, "name", n)
|
||||
fn, _ = sjson.SetBytes(fn, "name", n)
|
||||
}
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
fn, _ = sjson.Set(fn, "arguments", args.Raw)
|
||||
fn, _ = sjson.SetBytes(fn, "arguments", args.Raw)
|
||||
}
|
||||
// generate a paired random call_id and enqueue it so the
|
||||
// corresponding functionResponse can pop the earliest id
|
||||
// to preserve ordering when multiple calls are present.
|
||||
id := genCallID()
|
||||
fn, _ = sjson.Set(fn, "call_id", id)
|
||||
fn, _ = sjson.SetBytes(fn, "call_id", id)
|
||||
pendingCallIDs = append(pendingCallIDs, id)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", fn)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", fn)
|
||||
continue
|
||||
}
|
||||
|
||||
// function response from user
|
||||
if fr := p.Get("functionResponse"); fr.Exists() {
|
||||
fno := `{"type":"function_call_output"}`
|
||||
fno := []byte(`{"type":"function_call_output"}`)
|
||||
// Prefer a string result if present; otherwise embed the raw response as a string
|
||||
if res := fr.Get("response.result"); res.Exists() {
|
||||
fno, _ = sjson.Set(fno, "output", res.String())
|
||||
fno, _ = sjson.SetBytes(fno, "output", res.String())
|
||||
} else if resp := fr.Get("response"); resp.Exists() {
|
||||
fno, _ = sjson.Set(fno, "output", resp.Raw)
|
||||
fno, _ = sjson.SetBytes(fno, "output", resp.Raw)
|
||||
}
|
||||
// fno, _ = sjson.Set(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
|
||||
// fno, _ = sjson.SetBytes(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
|
||||
// attach the oldest queued call_id to pair the response
|
||||
// with its call. If the queue is empty, generate a new id.
|
||||
var id string
|
||||
@@ -182,8 +182,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
id = genCallID()
|
||||
}
|
||||
fno, _ = sjson.Set(fno, "call_id", id)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", fno)
|
||||
fno, _ = sjson.SetBytes(fno, "call_id", id)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", fno)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -193,8 +193,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Tools mapping: Gemini functionDeclarations -> Codex tools
|
||||
tools := root.Get("tools")
|
||||
if tools.IsArray() {
|
||||
out, _ = sjson.SetRaw(out, "tools", `[]`)
|
||||
out, _ = sjson.Set(out, "tool_choice", "auto")
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
|
||||
out, _ = sjson.SetBytes(out, "tool_choice", "auto")
|
||||
tarr := tools.Array()
|
||||
for i := 0; i < len(tarr); i++ {
|
||||
td := tarr[i]
|
||||
@@ -205,8 +205,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
farr := fns.Array()
|
||||
for j := 0; j < len(farr); j++ {
|
||||
fn := farr[j]
|
||||
tool := `{}`
|
||||
tool, _ = sjson.Set(tool, "type", "function")
|
||||
tool := []byte(`{}`)
|
||||
tool, _ = sjson.SetBytes(tool, "type", "function")
|
||||
if v := fn.Get("name"); v.Exists() {
|
||||
name := v.String()
|
||||
if short, ok := shortMap[name]; ok {
|
||||
@@ -214,32 +214,32 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
tool, _ = sjson.Set(tool, "name", name)
|
||||
tool, _ = sjson.SetBytes(tool, "name", name)
|
||||
}
|
||||
if v := fn.Get("description"); v.Exists() {
|
||||
tool, _ = sjson.Set(tool, "description", v.String())
|
||||
tool, _ = sjson.SetBytes(tool, "description", v.String())
|
||||
}
|
||||
if prm := fn.Get("parameters"); prm.Exists() {
|
||||
// Remove optional $schema field if present
|
||||
cleaned := prm.Raw
|
||||
cleaned, _ = sjson.Delete(cleaned, "$schema")
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
|
||||
cleaned := []byte(prm.Raw)
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
|
||||
} else if prm = fn.Get("parametersJsonSchema"); prm.Exists() {
|
||||
// Remove optional $schema field if present
|
||||
cleaned := prm.Raw
|
||||
cleaned, _ = sjson.Delete(cleaned, "$schema")
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
|
||||
cleaned := []byte(prm.Raw)
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
|
||||
}
|
||||
tool, _ = sjson.Set(tool, "strict", false)
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", tool)
|
||||
tool, _ = sjson.SetBytes(tool, "strict", false)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", tool)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fixed flags aligning with Codex expectations
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
|
||||
|
||||
// Convert Gemini thinkingConfig to Codex reasoning.effort.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
} else {
|
||||
@@ -263,7 +263,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
}
|
||||
@@ -272,22 +272,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
if !effortSet {
|
||||
// No thinking config, set default effort
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
|
||||
}
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "stream", true)
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
|
||||
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.SetBytes(out, "stream", true)
|
||||
out, _ = sjson.SetBytes(out, "store", false)
|
||||
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
|
||||
|
||||
var pathsToLower []string
|
||||
toolsResult := gjson.Get(out, "tools")
|
||||
toolsResult := gjson.GetBytes(out, "tools")
|
||||
util.Walk(toolsResult, "", "type", &pathsToLower)
|
||||
for _, p := range pathsToLower {
|
||||
fullPath := fmt.Sprintf("tools.%s", p)
|
||||
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
|
||||
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies the simple shortening rule for a single name.
|
||||
|
||||
@@ -7,9 +7,9 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -23,7 +23,7 @@ type ConvertCodexResponseToGeminiParams struct {
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ResponseID string
|
||||
LastStorageOutput string
|
||||
LastStorageOutput []byte
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
|
||||
@@ -38,19 +38,19 @@ type ConvertCodexResponseToGeminiParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
|
||||
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses
|
||||
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCodexResponseToGeminiParams{
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
LastStorageOutput: "",
|
||||
LastStorageOutput: nil,
|
||||
}
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
@@ -59,17 +59,17 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
typeStr := typeResult.String()
|
||||
|
||||
// Base Gemini response template
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
|
||||
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" && typeStr == "response.output_item.done" {
|
||||
template = (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
|
||||
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" {
|
||||
template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
|
||||
createdAtResult := rootResult.Get("response.created_at")
|
||||
if createdAtResult.Exists() {
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int()
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
}
|
||||
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
|
||||
}
|
||||
|
||||
// Handle function call completion
|
||||
@@ -78,7 +78,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
// Create function call part
|
||||
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
n := itemResult.Get("name").String()
|
||||
@@ -86,7 +86,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
|
||||
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
|
||||
}
|
||||
|
||||
// Parse and set arguments
|
||||
@@ -94,47 +94,48 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
if argsStr != "" {
|
||||
argsResult := gjson.Parse(argsStr)
|
||||
if argsResult.IsObject() {
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
|
||||
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
|
||||
}
|
||||
}
|
||||
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = template
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
|
||||
|
||||
// Use this return to storage message
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
if typeStr == "response.created" { // Handle response creation - set model and response ID
|
||||
template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String())
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String()
|
||||
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
|
||||
part := `{"thought":true,"text":""}`
|
||||
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"thought":true,"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
} else if typeStr == "response.completed" { // Handle response completion with usage metadata
|
||||
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
|
||||
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
|
||||
totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
} else {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" {
|
||||
return []string{(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput, template}
|
||||
} else {
|
||||
return []string{template}
|
||||
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 {
|
||||
return [][]byte{
|
||||
append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...),
|
||||
template,
|
||||
}
|
||||
}
|
||||
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToGeminiNonStream converts a non-streaming Codex response to a non-streaming Gemini response.
|
||||
@@ -149,32 +150,32 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Base Gemini response template for non-streaming
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
|
||||
|
||||
// Set model version
|
||||
template, _ = sjson.Set(template, "modelVersion", modelName)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
|
||||
|
||||
// Set response metadata from the completed response
|
||||
responseData := rootResult.Get("response")
|
||||
if responseData.Exists() {
|
||||
// Set response ID
|
||||
if responseId := responseData.Get("id"); responseId.Exists() {
|
||||
template, _ = sjson.Set(template, "responseId", responseId.String())
|
||||
template, _ = sjson.SetBytes(template, "responseId", responseId.String())
|
||||
}
|
||||
|
||||
// Set creation time
|
||||
if createdAt := responseData.Get("created_at"); createdAt.Exists() {
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
// Set usage metadata
|
||||
@@ -183,14 +184,14 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
totalTokens := inputTokens + outputTokens
|
||||
|
||||
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
}
|
||||
|
||||
// Process output content to build parts array
|
||||
hasToolCall := false
|
||||
var pendingFunctionCalls []string
|
||||
var pendingFunctionCalls [][]byte
|
||||
|
||||
flushPendingFunctionCalls := func() {
|
||||
if len(pendingFunctionCalls) == 0 {
|
||||
@@ -199,7 +200,7 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
// Add all pending function calls as individual parts
|
||||
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
|
||||
for _, fc := range pendingFunctionCalls {
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", fc)
|
||||
}
|
||||
pendingFunctionCalls = nil
|
||||
}
|
||||
@@ -215,9 +216,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Add thinking content
|
||||
if content := value.Get("content"); content.Exists() {
|
||||
part := `{"text":"","thought":true}`
|
||||
part, _ = sjson.Set(part, "text", content.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"text":"","thought":true}`)
|
||||
part, _ = sjson.SetBytes(part, "text", content.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
}
|
||||
|
||||
case "message":
|
||||
@@ -229,9 +230,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
content.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
if contentItem.Get("type").String() == "output_text" {
|
||||
if text := contentItem.Get("text"); text.Exists() {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", text.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -241,21 +242,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "function_call":
|
||||
// Collect function call for potential merging with consecutive ones
|
||||
hasToolCall = true
|
||||
functionCall := `{"functionCall":{"args":{},"name":""}}`
|
||||
functionCall := []byte(`{"functionCall":{"args":{},"name":""}}`)
|
||||
{
|
||||
n := value.Get("name").String()
|
||||
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
|
||||
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
|
||||
}
|
||||
|
||||
// Parse and set arguments
|
||||
if argsStr := value.Get("arguments").String(); argsStr != "" {
|
||||
argsResult := gjson.Parse(argsStr)
|
||||
if argsResult.IsObject() {
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
|
||||
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,9 +271,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set finish reason based on whether there were tool calls
|
||||
if hasToolCall {
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
}
|
||||
}
|
||||
return template
|
||||
@@ -307,6 +308,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
|
||||
return rev
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
@@ -29,42 +29,42 @@ import (
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
// Start with empty JSON object
|
||||
out := `{"instructions":""}`
|
||||
out := []byte(`{"instructions":""}`)
|
||||
|
||||
// Stream must be set to true
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// Codex not support temperature, top_p, top_k, max_output_tokens, so comment them
|
||||
// if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "temperature", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "temperature", v.Value())
|
||||
// }
|
||||
// if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "top_p", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "top_p", v.Value())
|
||||
// }
|
||||
// if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "top_k", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "top_k", v.Value())
|
||||
// }
|
||||
|
||||
// Map token limits
|
||||
// if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
|
||||
// }
|
||||
// if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
|
||||
// }
|
||||
|
||||
// Map reasoning effort
|
||||
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", v.Value())
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
|
||||
}
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
|
||||
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
|
||||
|
||||
// Model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Build tool name shortening map from original tools (if any)
|
||||
originalToolNameMap := map[string]string{}
|
||||
@@ -100,9 +100,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// if m.Get("role").String() == "system" {
|
||||
// c := m.Get("content")
|
||||
// if c.Type == gjson.String {
|
||||
// out, _ = sjson.Set(out, "instructions", c.String())
|
||||
// out, _ = sjson.SetBytes(out, "instructions", c.String())
|
||||
// } else if c.IsObject() && c.Get("type").String() == "text" {
|
||||
// out, _ = sjson.Set(out, "instructions", c.Get("text").String())
|
||||
// out, _ = sjson.SetBytes(out, "instructions", c.Get("text").String())
|
||||
// }
|
||||
// break
|
||||
// }
|
||||
@@ -110,7 +110,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// }
|
||||
|
||||
// Build input from messages, handling all message types including tool calls
|
||||
out, _ = sjson.SetRaw(out, "input", `[]`)
|
||||
out, _ = sjson.SetRawBytes(out, "input", []byte(`[]`))
|
||||
if messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
@@ -124,23 +124,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
content := m.Get("content").String()
|
||||
|
||||
// Create function_call_output object
|
||||
funcOutput := `{}`
|
||||
funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output")
|
||||
funcOutput, _ = sjson.Set(funcOutput, "call_id", toolCallID)
|
||||
funcOutput, _ = sjson.Set(funcOutput, "output", content)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", funcOutput)
|
||||
funcOutput := []byte(`{}`)
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
|
||||
|
||||
default:
|
||||
// Handle regular messages
|
||||
msg := `{}`
|
||||
msg, _ = sjson.Set(msg, "type", "message")
|
||||
msg := []byte(`{}`)
|
||||
msg, _ = sjson.SetBytes(msg, "type", "message")
|
||||
if role == "system" {
|
||||
msg, _ = sjson.Set(msg, "role", "developer")
|
||||
msg, _ = sjson.SetBytes(msg, "role", "developer")
|
||||
} else {
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
}
|
||||
|
||||
msg, _ = sjson.SetRaw(msg, "content", `[]`)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content", []byte(`[]`))
|
||||
|
||||
// Handle regular content
|
||||
c := m.Get("content")
|
||||
@@ -150,10 +150,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if role == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", partType)
|
||||
part, _ = sjson.Set(part, "text", c.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", partType)
|
||||
part, _ = sjson.SetBytes(part, "text", c.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
} else if c.Exists() && c.IsArray() {
|
||||
items := c.Array()
|
||||
for j := 0; j < len(items); j++ {
|
||||
@@ -165,39 +165,44 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if role == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", partType)
|
||||
part, _ = sjson.Set(part, "text", it.Get("text").String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", partType)
|
||||
part, _ = sjson.SetBytes(part, "text", it.Get("text").String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
case "image_url":
|
||||
// Map image inputs to input_image for Responses API
|
||||
if role == "user" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_image")
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_image")
|
||||
if u := it.Get("image_url.url"); u.Exists() {
|
||||
part, _ = sjson.Set(part, "image_url", u.String())
|
||||
part, _ = sjson.SetBytes(part, "image_url", u.String())
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
}
|
||||
case "file":
|
||||
if role == "user" {
|
||||
fileData := it.Get("file.file_data").String()
|
||||
filename := it.Get("file.filename").String()
|
||||
if fileData != "" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_file")
|
||||
part, _ = sjson.Set(part, "file_data", fileData)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_file")
|
||||
part, _ = sjson.SetBytes(part, "file_data", fileData)
|
||||
if filename != "" {
|
||||
part, _ = sjson.Set(part, "filename", filename)
|
||||
part, _ = sjson.SetBytes(part, "filename", filename)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
// Don't emit empty assistant messages when only tool_calls
|
||||
// are present — Responses API needs function_call items
|
||||
// directly, otherwise call_id matching fails (#2132).
|
||||
if role != "assistant" || len(gjson.GetBytes(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
|
||||
}
|
||||
|
||||
// Handle tool calls for assistant messages as separate top-level objects
|
||||
if role == "assistant" {
|
||||
@@ -208,9 +213,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
tc := toolCallsArr[j]
|
||||
if tc.Get("type").String() == "function" {
|
||||
// Create function_call as top-level object
|
||||
funcCall := `{}`
|
||||
funcCall, _ = sjson.Set(funcCall, "type", "function_call")
|
||||
funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String())
|
||||
funcCall := []byte(`{}`)
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "type", "function_call")
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "call_id", tc.Get("id").String())
|
||||
{
|
||||
name := tc.Get("function.name").String()
|
||||
if short, ok := originalToolNameMap[name]; ok {
|
||||
@@ -218,10 +223,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
funcCall, _ = sjson.Set(funcCall, "name", name)
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "name", name)
|
||||
}
|
||||
funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String())
|
||||
out, _ = sjson.SetRaw(out, "input.-1", funcCall)
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "arguments", tc.Get("function.arguments").String())
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", funcCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -235,26 +240,26 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
text := gjson.GetBytes(rawJSON, "text")
|
||||
if rf.Exists() {
|
||||
// Always create text object when response_format provided
|
||||
if !gjson.Get(out, "text").Exists() {
|
||||
out, _ = sjson.SetRaw(out, "text", `{}`)
|
||||
if !gjson.GetBytes(out, "text").Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
|
||||
}
|
||||
|
||||
rft := rf.Get("type").String()
|
||||
switch rft {
|
||||
case "text":
|
||||
out, _ = sjson.Set(out, "text.format.type", "text")
|
||||
out, _ = sjson.SetBytes(out, "text.format.type", "text")
|
||||
case "json_schema":
|
||||
js := rf.Get("json_schema")
|
||||
if js.Exists() {
|
||||
out, _ = sjson.Set(out, "text.format.type", "json_schema")
|
||||
out, _ = sjson.SetBytes(out, "text.format.type", "json_schema")
|
||||
if v := js.Get("name"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text.format.name", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.format.name", v.Value())
|
||||
}
|
||||
if v := js.Get("strict"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text.format.strict", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.format.strict", v.Value())
|
||||
}
|
||||
if v := js.Get("schema"); v.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "text.format.schema", []byte(v.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,23 +267,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// Map verbosity if provided
|
||||
if text.Exists() {
|
||||
if v := text.Get("verbosity"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text.verbosity", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
|
||||
}
|
||||
}
|
||||
} else if text.Exists() {
|
||||
// If only text.verbosity present (no response_format), map verbosity
|
||||
if v := text.Get("verbosity"); v.Exists() {
|
||||
if !gjson.Get(out, "text").Exists() {
|
||||
out, _ = sjson.SetRaw(out, "text", `{}`)
|
||||
if !gjson.GetBytes(out, "text").Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
|
||||
}
|
||||
out, _ = sjson.Set(out, "text.verbosity", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
// Map tools (flatten function fields)
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "tools", `[]`)
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
|
||||
arr := tools.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
t := arr[i]
|
||||
@@ -286,13 +291,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
|
||||
// Only "function" needs structural conversion because Chat Completions nests details under "function".
|
||||
if toolType != "" && toolType != "function" && t.IsObject() {
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", []byte(t.Raw))
|
||||
continue
|
||||
}
|
||||
|
||||
if toolType == "function" {
|
||||
item := `{}`
|
||||
item, _ = sjson.Set(item, "type", "function")
|
||||
item := []byte(`{}`)
|
||||
item, _ = sjson.SetBytes(item, "type", "function")
|
||||
fn := t.Get("function")
|
||||
if fn.Exists() {
|
||||
if v := fn.Get("name"); v.Exists() {
|
||||
@@ -302,19 +307,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
item, _ = sjson.Set(item, "name", name)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
}
|
||||
if v := fn.Get("description"); v.Exists() {
|
||||
item, _ = sjson.Set(item, "description", v.Value())
|
||||
item, _ = sjson.SetBytes(item, "description", v.Value())
|
||||
}
|
||||
if v := fn.Get("parameters"); v.Exists() {
|
||||
item, _ = sjson.SetRaw(item, "parameters", v.Raw)
|
||||
item, _ = sjson.SetRawBytes(item, "parameters", []byte(v.Raw))
|
||||
}
|
||||
if v := fn.Get("strict"); v.Exists() {
|
||||
item, _ = sjson.Set(item, "strict", v.Value())
|
||||
item, _ = sjson.SetBytes(item, "strict", v.Value())
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", item)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", item)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -325,7 +330,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
|
||||
switch {
|
||||
case tc.Type == gjson.String:
|
||||
out, _ = sjson.Set(out, "tool_choice", tc.String())
|
||||
out, _ = sjson.SetBytes(out, "tool_choice", tc.String())
|
||||
case tc.IsObject():
|
||||
tcType := tc.Get("type").String()
|
||||
if tcType == "function" {
|
||||
@@ -337,21 +342,21 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
}
|
||||
choice := `{}`
|
||||
choice, _ = sjson.Set(choice, "type", "function")
|
||||
choice := []byte(`{}`)
|
||||
choice, _ = sjson.SetBytes(choice, "type", "function")
|
||||
if name != "" {
|
||||
choice, _ = sjson.Set(choice, "name", name)
|
||||
choice, _ = sjson.SetBytes(choice, "name", name)
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", choice)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", choice)
|
||||
} else if tcType != "" {
|
||||
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(tc.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
return []byte(out)
|
||||
out, _ = sjson.SetBytes(out, "store", false)
|
||||
return out
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies the simple shortening rule for a single name.
|
||||
|
||||
@@ -0,0 +1,635 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
|
||||
// Expects developer msg + user msg + function_call + function_call_output.
|
||||
// No empty assistant message should appear between user and function_call.
|
||||
func TestToolCallSimple(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "sunny, 22C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
// system -> developer
|
||||
if items[0].Get("type").String() != "message" {
|
||||
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
|
||||
}
|
||||
if items[0].Get("role").String() != "developer" {
|
||||
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// user
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "user" {
|
||||
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
|
||||
// function_call, not an empty assistant msg
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("name").String() != "get_weather" {
|
||||
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
|
||||
}
|
||||
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
|
||||
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
|
||||
}
|
||||
|
||||
// function_call_output
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
if items[3].Get("output").String() != "sunny, 22C" {
|
||||
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Assistant has both text content and tool_calls — the message should
|
||||
// be emitted (non-empty content), followed by function_call items.
|
||||
func TestToolCallWithContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me check the weather for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "rainy, 15C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + assistant(with content) + function_call + function_call_output
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// assistant with content — should be kept
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
contentParts := items[1].Get("content").Array()
|
||||
if len(contentParts) == 0 {
|
||||
t.Errorf("item 1: assistant message should have content parts")
|
||||
}
|
||||
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
|
||||
// and outputs must be translated and paired correctly.
|
||||
func TestMultipleToolCalls(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_paris",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_london",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"London\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_tokyo",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Tokyo\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
|
||||
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
|
||||
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + 3 function_call + 3 function_call_output = 7
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
|
||||
for i, expectedID := range expectedCallIDs {
|
||||
idx := i + 1
|
||||
if items[idx].Get("type").String() != "function_call" {
|
||||
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedID {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
|
||||
for i, expectedOutput := range expectedOutputs {
|
||||
idx := i + 4
|
||||
if items[idx].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
|
||||
}
|
||||
if items[idx].Get("output").String() != expectedOutput {
|
||||
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for #2132: tool-call-only assistant messages (content:null)
|
||||
// must not produce an empty message item in the translated output.
|
||||
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Call a tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_x",
|
||||
"type": "function",
|
||||
"function": {"name": "do_thing", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_thing",
|
||||
"description": "Do a thing",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
typ := item.Get("type").String()
|
||||
role := item.Get("role").String()
|
||||
if typ == "message" && role == "assistant" {
|
||||
contentArr := item.Get("content").Array()
|
||||
if len(contentArr) == 0 {
|
||||
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should be exactly: user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected user message")
|
||||
}
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Two rounds of tool calling in one conversation, with a text reply in between.
|
||||
func TestMultiTurnToolCalling(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
|
||||
{"role": "assistant", "content": "It is sunny in Paris."},
|
||||
{"role": "user", "content": "And London?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: unexpected empty assistant message", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// round 1
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("call_id").String() != "call_r1" {
|
||||
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
|
||||
// text reply between rounds
|
||||
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
|
||||
}
|
||||
|
||||
// round 2
|
||||
if items[5].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
|
||||
}
|
||||
if items[5].Get("call_id").String() != "call_r2" {
|
||||
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
|
||||
}
|
||||
if items[6].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Tool names over 64 chars get shortened, call_id stays the same.
|
||||
func TestToolNameShortening(t *testing.T) {
|
||||
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
|
||||
if len(longName) <= 64 {
|
||||
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
|
||||
}
|
||||
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do it"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_long",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"description": "A tool with a very long name",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// find function_call
|
||||
var funcCallItem gjson.Result
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
funcCallItem = item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !funcCallItem.Exists() {
|
||||
t.Fatal("no function_call item found in output")
|
||||
}
|
||||
|
||||
// call_id unchanged
|
||||
if funcCallItem.Get("call_id").String() != "call_long" {
|
||||
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
|
||||
}
|
||||
|
||||
// name must be truncated
|
||||
translatedName := funcCallItem.Get("name").String()
|
||||
if translatedName == longName {
|
||||
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
|
||||
}
|
||||
if len(translatedName) > 64 {
|
||||
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
|
||||
}
|
||||
}
|
||||
|
||||
// content:"" (empty string, not null) should be treated the same as null.
|
||||
func TestEmptyStringContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do something"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_empty",
|
||||
"type": "function",
|
||||
"function": {"name": "action", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "action",
|
||||
"description": "An action",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: empty assistant message from content:\"\"", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Errorf("expected 3 input items, got %d", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
// Every function_call_output must have a matching function_call by call_id.
|
||||
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Multi-tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
|
||||
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
|
||||
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
|
||||
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// collect call_ids from function_call items
|
||||
callIDs := make(map[string]bool)
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
callIDs[item.Get("call_id").String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "function_call_output" {
|
||||
outID := item.Get("call_id").String()
|
||||
if !callIDs[outID] {
|
||||
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2 calls, 2 outputs
|
||||
funcCallCount := 0
|
||||
funcOutputCount := 0
|
||||
for _, item := range items {
|
||||
switch item.Get("type").String() {
|
||||
case "function_call":
|
||||
funcCallCount++
|
||||
case "function_call_output":
|
||||
funcOutputCount++
|
||||
}
|
||||
}
|
||||
if funcCallCount != 2 {
|
||||
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
|
||||
}
|
||||
if funcOutputCount != 2 {
|
||||
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Tools array should carry over to the Responses format output.
|
||||
func TestToolsDefinitionTranslated(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
tools := gjson.Get(result, "tools").Array()
|
||||
if len(tools) == 0 {
|
||||
t.Fatal("no tools found in output")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, tool := range tools {
|
||||
if tool.Get("name").String() == "search" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
|
||||
}
|
||||
}
|
||||
@@ -41,8 +41,8 @@ type ConvertCliToOpenAIParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCliToOpenAIParams{
|
||||
Model: modelName,
|
||||
@@ -55,12 +55,12 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -70,67 +70,67 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
(*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String()
|
||||
(*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int()
|
||||
(*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String()
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Extract and set the model version.
|
||||
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
|
||||
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelResult.String())
|
||||
} else if cachedModel != "" {
|
||||
template, _ = sjson.Set(template, "model", cachedModel)
|
||||
template, _ = sjson.SetBytes(template, "model", cachedModel)
|
||||
} else if modelName != "" {
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
}
|
||||
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
||||
|
||||
// Extract and set the response ID.
|
||||
template, _ = sjson.Set(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() {
|
||||
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
}
|
||||
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
}
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
}
|
||||
|
||||
if dataType == "response.reasoning_summary_text.delta" {
|
||||
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", deltaResult.String())
|
||||
}
|
||||
} else if dataType == "response.reasoning_summary_text.done" {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", "\n\n")
|
||||
} else if dataType == "response.output_text.delta" {
|
||||
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", deltaResult.String())
|
||||
}
|
||||
} else if dataType == "response.completed" {
|
||||
finishReason := "stop"
|
||||
if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||
} else if dataType == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Increment index for this new function call item.
|
||||
@@ -138,9 +138,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -148,59 +148,59 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", "")
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||
|
||||
deltaValue := rootResult.Get("delta").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.done" {
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||
// Arguments were already streamed via delta events; nothing to emit.
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||
fullArgs := rootResult.Get("arguments").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||
// Tool call was already announced via output_item.added; skip emission.
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -208,17 +208,17 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAINonStream converts a non-streaming Codex response to a non-streaming OpenAI response.
|
||||
@@ -233,53 +233,53 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
unixTimestamp := time.Now().Unix()
|
||||
|
||||
responseResult := rootResult.Get("response")
|
||||
|
||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelResult := responseResult.Get("model"); modelResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() {
|
||||
template, _ = sjson.Set(template, "created", createdAtResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "created", createdAtResult.Int())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if idResult := responseResult.Get("id"); idResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", idResult.String())
|
||||
template, _ = sjson.SetBytes(template, "id", idResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := responseResult.Get("usage"); usageResult.Exists() {
|
||||
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
}
|
||||
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
}
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,7 +289,7 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
outputArray := outputResult.Array()
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []string
|
||||
var toolCalls [][]byte
|
||||
|
||||
for _, outputItem := range outputArray {
|
||||
outputType := outputItem.Get("type").String()
|
||||
@@ -319,10 +319,10 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
}
|
||||
case "function_call":
|
||||
// Handle function call content
|
||||
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
functionCallTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
|
||||
if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", callIdResult.String())
|
||||
}
|
||||
|
||||
if nameResult := outputItem.Get("name"); nameResult.Exists() {
|
||||
@@ -331,11 +331,11 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", n)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", n)
|
||||
}
|
||||
|
||||
if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", argsResult.String())
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, functionCallTemplate)
|
||||
@@ -344,22 +344,22 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
|
||||
// Set content and reasoning content if found
|
||||
if contentText != "" {
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", contentText)
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.content", contentText)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
|
||||
}
|
||||
|
||||
if reasoningText != "" {
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.reasoning_content", reasoningText)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
|
||||
}
|
||||
|
||||
// Add tool calls if any
|
||||
if len(toolCalls) > 0 {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls", []byte(`[]`))
|
||||
for _, toolCall := range toolCalls {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls.-1", toolCall)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,8 +367,8 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if statusResult := responseResult.Get("status"); statusResult.Exists() {
|
||||
status := statusResult.String()
|
||||
if status == "completed" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *test
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
gotModel := gjson.Get(out[0], "model").String()
|
||||
gotModel := gjson.GetBytes(out[0], "model").String()
|
||||
if gotModel != modelName {
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
@@ -40,8 +40,53 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
gotModel := gjson.Get(out[0], "model").String()
|
||||
gotModel := gjson.GetBytes(out[0], "model").String()
|
||||
if gotModel != modelName {
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToOpenAI_ToolCallChunkOmitsNullContentFields(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
|
||||
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
|
||||
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls").Exists() {
|
||||
t.Fatalf("expected tool_calls to exist, got %s", string(out[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToOpenAI_ToolCallArgumentsDeltaOmitsNullContentFields(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected tool call announcement chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.function_call_arguments.delta","delta":"{\"query\":\"OpenAI\"}"}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
|
||||
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
|
||||
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls.0.function.arguments").Exists() {
|
||||
t.Fatalf("expected tool call arguments delta to exist, got %s", string(out[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package responses
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -12,8 +13,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
if inputResult.Type == gjson.String {
|
||||
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
|
||||
input, _ := sjson.SetBytes([]byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`), "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", input)
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
@@ -39,6 +40,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
rawJSON = normalizeCodexBuiltinTools(rawJSON)
|
||||
|
||||
return rawJSON
|
||||
}
|
||||
@@ -82,3 +84,59 @@ func convertSystemRoleToDeveloper(rawJSON []byte) []byte {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// normalizeCodexBuiltinTools rewrites legacy/preview built-in tool variants to the
|
||||
// stable names expected by the current Codex upstream.
|
||||
func normalizeCodexBuiltinTools(rawJSON []byte) []byte {
|
||||
result := rawJSON
|
||||
|
||||
tools := gjson.GetBytes(result, "tools")
|
||||
if tools.IsArray() {
|
||||
toolArray := tools.Array()
|
||||
for i := 0; i < len(toolArray); i++ {
|
||||
typePath := fmt.Sprintf("tools.%d.type", i)
|
||||
result = normalizeCodexBuiltinToolAtPath(result, typePath)
|
||||
}
|
||||
}
|
||||
|
||||
result = normalizeCodexBuiltinToolAtPath(result, "tool_choice.type")
|
||||
|
||||
toolChoiceTools := gjson.GetBytes(result, "tool_choice.tools")
|
||||
if toolChoiceTools.IsArray() {
|
||||
toolArray := toolChoiceTools.Array()
|
||||
for i := 0; i < len(toolArray); i++ {
|
||||
typePath := fmt.Sprintf("tool_choice.tools.%d.type", i)
|
||||
result = normalizeCodexBuiltinToolAtPath(result, typePath)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeCodexBuiltinToolAtPath(rawJSON []byte, path string) []byte {
|
||||
currentType := gjson.GetBytes(rawJSON, path).String()
|
||||
normalizedType := normalizeCodexBuiltinToolType(currentType)
|
||||
if normalizedType == "" {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
updated, err := sjson.SetBytes(rawJSON, path, normalizedType)
|
||||
if err != nil {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
log.Debugf("codex responses: normalized builtin tool type at %s from %q to %q", path, currentType, normalizedType)
|
||||
return updated
|
||||
}
|
||||
|
||||
// normalizeCodexBuiltinToolType centralizes the current known Codex Responses
|
||||
// built-in tool alias compatibility. If Codex introduces more legacy aliases,
|
||||
// extend this helper instead of adding path-specific rewrite logic elsewhere.
|
||||
func normalizeCodexBuiltinToolType(toolType string) string {
|
||||
switch toolType {
|
||||
case "web_search_preview", "web_search_preview_2025_03_11":
|
||||
return "web_search"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,6 +264,52 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIResponsesRequestToCodex_NormalizesWebSearchPreview(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.4-mini",
|
||||
"input": "find latest OpenAI model news",
|
||||
"tools": [
|
||||
{"type": "web_search_preview_2025_03_11"}
|
||||
],
|
||||
"tool_choice": {
|
||||
"type": "allowed_tools",
|
||||
"tools": [
|
||||
{"type": "web_search_preview"},
|
||||
{"type": "web_search_preview_2025_03_11"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "allowed_tools" {
|
||||
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "allowed_tools", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.tools.0.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.tools.1.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.tools.1.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIResponsesRequestToCodex_NormalizesTopLevelToolChoicePreviewAlias(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.4-mini",
|
||||
"input": "find latest OpenAI model news",
|
||||
"tool_choice": {"type": "web_search_preview_2025_03_11"}
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
|
||||
@@ -3,7 +3,6 @@ package responses
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -11,23 +10,25 @@ import (
|
||||
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||
// to OpenAI Responses SSE events (response.*).
|
||||
|
||||
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
|
||||
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) [][]byte {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
out := fmt.Sprintf("data: %s", string(rawJSON))
|
||||
return []string{out}
|
||||
out := make([]byte, 0, len(rawJSON)+len("data: "))
|
||||
out = append(out, []byte("data: ")...)
|
||||
out = append(out, rawJSON...)
|
||||
return [][]byte{out}
|
||||
}
|
||||
return []string{string(rawJSON)}
|
||||
return [][]byte{rawJSON}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
|
||||
// from a non-streaming OpenAI Chat Completions response.
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []byte {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
responseResult := rootResult.Get("response")
|
||||
return responseResult.Raw
|
||||
return []byte(responseResult.Raw)
|
||||
}
|
||||
|
||||
67
internal/translator/common/bytes.go
Normal file
67
internal/translator/common/bytes.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func WrapGeminiCLIResponse(response []byte) []byte {
|
||||
out, err := sjson.SetRawBytes([]byte(`{"response":{}}`), "response", response)
|
||||
if err != nil {
|
||||
return response
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GeminiTokenCountJSON(count int64) []byte {
|
||||
out := make([]byte, 0, 96)
|
||||
out = append(out, `{"totalTokens":`...)
|
||||
out = strconv.AppendInt(out, count, 10)
|
||||
out = append(out, `,"promptTokensDetails":[{"modality":"TEXT","tokenCount":`...)
|
||||
out = strconv.AppendInt(out, count, 10)
|
||||
out = append(out, `}]}`...)
|
||||
return out
|
||||
}
|
||||
|
||||
func ClaudeInputTokensJSON(count int64) []byte {
|
||||
out := make([]byte, 0, 32)
|
||||
out = append(out, `{"input_tokens":`...)
|
||||
out = strconv.AppendInt(out, count, 10)
|
||||
out = append(out, '}')
|
||||
return out
|
||||
}
|
||||
|
||||
func SSEEventData(event string, payload []byte) []byte {
|
||||
out := make([]byte, 0, len(event)+len(payload)+14)
|
||||
out = append(out, "event: "...)
|
||||
out = append(out, event...)
|
||||
out = append(out, '\n')
|
||||
out = append(out, "data: "...)
|
||||
out = append(out, payload...)
|
||||
return out
|
||||
}
|
||||
|
||||
func AppendSSEEventString(out []byte, event, payload string, trailingNewlines int) []byte {
|
||||
out = append(out, "event: "...)
|
||||
out = append(out, event...)
|
||||
out = append(out, '\n')
|
||||
out = append(out, "data: "...)
|
||||
out = append(out, payload...)
|
||||
for i := 0; i < trailingNewlines; i++ {
|
||||
out = append(out, '\n')
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AppendSSEEventBytes(out []byte, event string, payload []byte, trailingNewlines int) []byte {
|
||||
out = append(out, "event: "...)
|
||||
out = append(out, event...)
|
||||
out = append(out, '\n')
|
||||
out = append(out, "data: "...)
|
||||
out = append(out, payload...)
|
||||
for i := 0; i < trailingNewlines; i++ {
|
||||
out = append(out, '\n')
|
||||
}
|
||||
return out
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user