mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-04 03:31:21 +00:00
Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fa527193c | ||
|
|
ed0eb51b4d | ||
|
|
0e4f669c8b | ||
|
|
76c064c729 | ||
|
|
d2f652f436 | ||
|
|
6a452a54d5 | ||
|
|
9e5693e74f | ||
|
|
528b1a2307 | ||
|
|
0cc978ec1d | ||
|
|
fee736933b | ||
|
|
5c99846ecf | ||
|
|
d475aaba96 | ||
|
|
1dc4ecb1b8 | ||
|
|
1315f710f5 | ||
|
|
96f55570f7 | ||
|
|
0906aeca87 | ||
|
|
97c0487add | ||
|
|
a576088d5f | ||
|
|
66ff916838 | ||
|
|
7b0453074e | ||
|
|
a000eb523d | ||
|
|
18a4fedc7f | ||
|
|
5d6cdccda0 | ||
|
|
1b7f4ac3e1 | ||
|
|
afc1a5b814 | ||
|
|
7ed38db54f | ||
|
|
28c10f4e69 | ||
|
|
6e12441a3b | ||
|
|
65c439c18d | ||
|
|
0ed2d16596 | ||
|
|
db335ac616 | ||
|
|
e6690cb447 | ||
|
|
35907416b8 | ||
|
|
e8bb350467 | ||
|
|
5331d51f27 | ||
|
|
755ca75879 | ||
|
|
2398ebad55 | ||
|
|
c1bf298216 | ||
|
|
e005208d76 | ||
|
|
d1df70d02f | ||
|
|
52c1fa025e | ||
|
|
680105f84d | ||
|
|
f7069e9548 | ||
|
|
7275e99b41 | ||
|
|
c28b65f849 | ||
|
|
793840cdb4 | ||
|
|
8f421de532 | ||
|
|
be2dd60ee7 | ||
|
|
ea3e0b713e | ||
|
|
8179d5a8a4 | ||
|
|
6fa7abe434 | ||
|
|
5135c22cd6 | ||
|
|
1e27990561 | ||
|
|
e1e9fc43c1 | ||
|
|
b2921518ac | ||
|
|
dd64adbeeb | ||
|
|
616d41c06a | ||
|
|
e0e337aeb9 | ||
|
|
d52839fced | ||
|
|
4022e69651 | ||
|
|
c3762328a5 |
@@ -1,6 +1,6 @@
|
||||
# CLIProxyAPI Plus
|
||||
|
||||
[English](README.md) | 中文
|
||||
[English](README.md) | 中文 | [日本語](README_JA.md)
|
||||
|
||||
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
||||
|
||||
|
||||
183
README_JA.md
Normal file
183
README_JA.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# CLI Proxy API
|
||||
|
||||
[English](README.md) | [中文](README_CN.md) | 日本語
|
||||
|
||||
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
|
||||
|
||||
OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポートしています。
|
||||
|
||||
ローカルまたはマルチアカウントのCLIアクセスを、OpenAI(Responses含む)/Gemini/Claude互換のクライアントやSDKで利用できます。
|
||||
|
||||
## スポンサー
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
|
||||
|
||||
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
|
||||
|
||||
GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
|
||||
---
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## 概要
|
||||
|
||||
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
|
||||
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
|
||||
- OAuthログインによるClaude Codeサポート
|
||||
- OAuthログインによるQwen Codeサポート
|
||||
- OAuthログインによるiFlowサポート
|
||||
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
|
||||
- ストリーミングおよび非ストリーミングレスポンス
|
||||
- 関数呼び出し/ツールのサポート
|
||||
- マルチモーダル入力サポート(テキストと画像)
|
||||
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||
- Generative Language APIキーのサポート
|
||||
- AI Studioビルドのマルチアカウント負荷分散
|
||||
- Gemini CLIのマルチアカウント負荷分散
|
||||
- Claude Codeのマルチアカウント負荷分散
|
||||
- Qwen Codeのマルチアカウント負荷分散
|
||||
- iFlowのマルチアカウント負荷分散
|
||||
- OpenAI Codexのマルチアカウント負荷分散
|
||||
- 設定によるOpenAI互換アップストリームプロバイダー(例:OpenRouter)
|
||||
- プロキシ埋め込み用の再利用可能なGo SDK(`docs/sdk-usage.md`を参照)
|
||||
|
||||
## はじめに
|
||||
|
||||
CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/)
|
||||
|
||||
## 管理API
|
||||
|
||||
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
|
||||
|
||||
## Amp CLIサポート
|
||||
|
||||
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
|
||||
|
||||
- Ampの APIパターン用のプロバイダールートエイリアス(`/api/provider/{provider}/v1...`)
|
||||
- OAuth認証およびアカウント機能用の管理プロキシ
|
||||
- 自動ルーティングによるスマートモデルフォールバック
|
||||
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
|
||||
|
||||
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||
|
||||
## SDKドキュメント
|
||||
|
||||
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
|
||||
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
|
||||
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
|
||||
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
|
||||
- カスタムプロバイダーの例:`examples/custom-provider`
|
||||
|
||||
## コントリビューション
|
||||
|
||||
コントリビューションを歓迎します!お気軽にPull Requestを送ってください。
|
||||
|
||||
1. リポジトリをフォーク
|
||||
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`)
|
||||
3. 変更をコミット(`git commit -m 'Add some amazing feature'`)
|
||||
4. ブランチにプッシュ(`git push origin feature/amazing-feature`)
|
||||
5. Pull Requestを作成
|
||||
|
||||
## 関連プロジェクト
|
||||
|
||||
CLIProxyAPIをベースにした以下のプロジェクトがあります:
|
||||
|
||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
||||
|
||||
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
|
||||
|
||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
||||
|
||||
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
|
||||
|
||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
||||
|
||||
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデル(Gemini、Codex、Antigravity)を即座に切り替えるCLIラッパー - APIキー不要
|
||||
|
||||
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
||||
|
||||
CLIProxyAPI管理用のmacOSネイティブGUI:OAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
|
||||
|
||||
### [CodMate](https://github.com/loocor/CodMate)
|
||||
|
||||
CLI AIセッション(Codex、Claude Code、Gemini CLI)を管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
|
||||
|
||||
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||
|
||||
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
|
||||
|
||||
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||
|
||||
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
|
||||
|
||||
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
||||
|
||||
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
|
||||
|
||||
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||
|
||||
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替え(Main / Plus)、自動ダウンロードおよび自動更新に対応
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
|
||||
|
||||
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
||||
|
||||
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
|
||||
|
||||
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||
|
||||
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
|
||||
|
||||
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
|
||||
|
||||
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LAN(ローカルエリアネットワーク)を介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
## その他の選択肢
|
||||
|
||||
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです:
|
||||
|
||||
### [9Router](https://github.com/decolua/9router)
|
||||
|
||||
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換(OpenAI/Claude/Gemini/Ollama)、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツール(Cursor、Claude Code、Cline、RooCode)のサポートをゼロから構築 - APIキー不要
|
||||
|
||||
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
||||
|
||||
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
|
||||
|
||||
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
|
||||
|
||||
> [!NOTE]
|
||||
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||
|
||||
## ライセンス
|
||||
|
||||
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。
|
||||
@@ -95,6 +95,7 @@ func main() {
|
||||
var kiroIDCRegion string
|
||||
var kiroIDCFlow string
|
||||
var githubCopilotLogin bool
|
||||
var codeBuddyLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
@@ -132,6 +133,7 @@ func main() {
|
||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
@@ -516,6 +518,9 @@ func main() {
|
||||
} else if githubCopilotLogin {
|
||||
// Handle GitHub Copilot login
|
||||
cmd.DoGitHubCopilotLogin(cfg, options)
|
||||
} else if codeBuddyLogin {
|
||||
// Handle CodeBuddy login
|
||||
cmd.DoCodeBuddyLogin(cfg, options)
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
|
||||
@@ -175,12 +175,19 @@ nonstream-keepalive-interval: 0
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||
# when the client omits them, while OS/arch remain runtime-derived. When
|
||||
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
|
||||
# while user-agent/package-version/runtime-version seed a software fingerprint that can
|
||||
# still upgrade to newer official Claude client versions.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# os: "MacOS"
|
||||
# arch: "arm64"
|
||||
# timeout: "600"
|
||||
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
|
||||
|
||||
# Default headers for Codex OAuth model requests.
|
||||
# These are used only for file-backed/OAuth Codex requests when the client
|
||||
@@ -231,7 +238,9 @@ nonstream-keepalive-interval: 0
|
||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# thinking: # optional: omit to default to levels ["low","medium","high"]
|
||||
# levels: ["low", "medium", "high"]
|
||||
# # You may repeat the same alias to build an internal model pool.
|
||||
# # The client still sees only one alias in the model list.
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
|
||||
2
go.mod
2
go.mod
@@ -91,8 +91,8 @@ require (
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -66,8 +67,10 @@ type callbackForwarder struct {
|
||||
}
|
||||
|
||||
var (
|
||||
callbackForwardersMu sync.Mutex
|
||||
callbackForwarders = make(map[int]*callbackForwarder)
|
||||
callbackForwardersMu sync.Mutex
|
||||
callbackForwarders = make(map[int]*callbackForwarder)
|
||||
errAuthFileMustBeJSON = errors.New("auth file must be .json")
|
||||
errAuthFileNotFound = errors.New("auth file not found")
|
||||
)
|
||||
|
||||
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
|
||||
@@ -579,32 +582,57 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
if file, err := c.FormFile("file"); err == nil && file != nil {
|
||||
name := filepath.Base(file.Filename)
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
c.JSON(400, gin.H{"error": "file must be .json"})
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(h.cfg.AuthDir, name)
|
||||
if !filepath.IsAbs(dst) {
|
||||
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
||||
dst = abs
|
||||
|
||||
fileHeaders, errMultipart := h.multipartAuthFileHeaders(c)
|
||||
if errMultipart != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid multipart form: %v", errMultipart)})
|
||||
return
|
||||
}
|
||||
if len(fileHeaders) == 1 {
|
||||
if _, errUpload := h.storeUploadedAuthFile(ctx, fileHeaders[0]); errUpload != nil {
|
||||
if errors.Is(errUpload, errAuthFileMustBeJSON) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file must be .json"})
|
||||
return
|
||||
}
|
||||
}
|
||||
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errUpload.Error()})
|
||||
return
|
||||
}
|
||||
data, errRead := os.ReadFile(dst)
|
||||
if errRead != nil {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if len(fileHeaders) > 1 {
|
||||
uploaded := make([]string, 0, len(fileHeaders))
|
||||
failed := make([]gin.H, 0)
|
||||
for _, file := range fileHeaders {
|
||||
name, errUpload := h.storeUploadedAuthFile(ctx, file)
|
||||
if errUpload != nil {
|
||||
failureName := ""
|
||||
if file != nil {
|
||||
failureName = filepath.Base(file.Filename)
|
||||
}
|
||||
msg := errUpload.Error()
|
||||
if errors.Is(errUpload, errAuthFileMustBeJSON) {
|
||||
msg = "file must be .json"
|
||||
}
|
||||
failed = append(failed, gin.H{"name": failureName, "error": msg})
|
||||
continue
|
||||
}
|
||||
uploaded = append(uploaded, name)
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
c.JSON(http.StatusMultiStatus, gin.H{
|
||||
"status": "partial",
|
||||
"uploaded": len(uploaded),
|
||||
"files": uploaded,
|
||||
"failed": failed,
|
||||
})
|
||||
return
|
||||
}
|
||||
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
|
||||
c.JSON(500, gin.H{"error": errReg.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "uploaded": len(uploaded), "files": uploaded})
|
||||
return
|
||||
}
|
||||
if c.ContentType() == "multipart/form-data" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
|
||||
return
|
||||
}
|
||||
name := c.Query("name")
|
||||
@@ -621,17 +649,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
if !filepath.IsAbs(dst) {
|
||||
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
||||
dst = abs
|
||||
}
|
||||
}
|
||||
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
|
||||
return
|
||||
}
|
||||
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
|
||||
if err = h.writeAuthFile(ctx, filepath.Base(name), data); err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -678,11 +696,182 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
|
||||
return
|
||||
}
|
||||
name := c.Query("name")
|
||||
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
||||
|
||||
names, errNames := requestedAuthFileNamesForDelete(c)
|
||||
if errNames != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errNames.Error()})
|
||||
return
|
||||
}
|
||||
if len(names) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
if len(names) == 1 {
|
||||
if _, status, errDelete := h.deleteAuthFileByName(ctx, names[0]); errDelete != nil {
|
||||
c.JSON(status, gin.H{"error": errDelete.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
|
||||
deletedFiles := make([]string, 0, len(names))
|
||||
failed := make([]gin.H, 0)
|
||||
for _, name := range names {
|
||||
deletedName, _, errDelete := h.deleteAuthFileByName(ctx, name)
|
||||
if errDelete != nil {
|
||||
failed = append(failed, gin.H{"name": name, "error": errDelete.Error()})
|
||||
continue
|
||||
}
|
||||
deletedFiles = append(deletedFiles, deletedName)
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
c.JSON(http.StatusMultiStatus, gin.H{
|
||||
"status": "partial",
|
||||
"deleted": len(deletedFiles),
|
||||
"files": deletedFiles,
|
||||
"failed": failed,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "deleted": len(deletedFiles), "files": deletedFiles})
|
||||
}
|
||||
|
||||
func (h *Handler) multipartAuthFileHeaders(c *gin.Context) ([]*multipart.FileHeader, error) {
|
||||
if h == nil || c == nil || c.ContentType() != "multipart/form-data" {
|
||||
return nil, nil
|
||||
}
|
||||
form, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if form == nil || len(form.File) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(form.File))
|
||||
for key := range form.File {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
headers := make([]*multipart.FileHeader, 0)
|
||||
for _, key := range keys {
|
||||
headers = append(headers, form.File[key]...)
|
||||
}
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
func (h *Handler) storeUploadedAuthFile(ctx context.Context, file *multipart.FileHeader) (string, error) {
|
||||
if file == nil {
|
||||
return "", fmt.Errorf("no file uploaded")
|
||||
}
|
||||
name := filepath.Base(strings.TrimSpace(file.Filename))
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
return "", errAuthFileMustBeJSON
|
||||
}
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
data, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read uploaded file: %w", err)
|
||||
}
|
||||
if err := h.writeAuthFile(ctx, name, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func (h *Handler) writeAuthFile(ctx context.Context, name string, data []byte) error {
|
||||
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
if !filepath.IsAbs(dst) {
|
||||
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
||||
dst = abs
|
||||
}
|
||||
}
|
||||
auth, err := h.buildAuthFromFileData(dst, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
|
||||
return fmt.Errorf("failed to write file: %w", errWrite)
|
||||
}
|
||||
if err := h.upsertAuthRecord(ctx, auth); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requestedAuthFileNamesForDelete(c *gin.Context) ([]string, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
names := uniqueAuthFileNames(c.QueryArray("name"))
|
||||
if len(names) > 0 {
|
||||
return names, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body")
|
||||
}
|
||||
body = bytes.TrimSpace(body)
|
||||
if len(body) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var objectBody struct {
|
||||
Name string `json:"name"`
|
||||
Names []string `json:"names"`
|
||||
}
|
||||
if body[0] == '[' {
|
||||
var arrayBody []string
|
||||
if err := json.Unmarshal(body, &arrayBody); err != nil {
|
||||
return nil, fmt.Errorf("invalid request body")
|
||||
}
|
||||
return uniqueAuthFileNames(arrayBody), nil
|
||||
}
|
||||
if err := json.Unmarshal(body, &objectBody); err != nil {
|
||||
return nil, fmt.Errorf("invalid request body")
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(objectBody.Names)+1)
|
||||
if strings.TrimSpace(objectBody.Name) != "" {
|
||||
out = append(out, objectBody.Name)
|
||||
}
|
||||
out = append(out, objectBody.Names...)
|
||||
return uniqueAuthFileNames(out), nil
|
||||
}
|
||||
|
||||
func uniqueAuthFileNames(names []string) []string {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(names))
|
||||
out := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
||||
return "", http.StatusBadRequest, fmt.Errorf("invalid name")
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
targetID := ""
|
||||
@@ -699,22 +888,19 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
}
|
||||
if errRemove := os.Remove(targetPath); errRemove != nil {
|
||||
if os.IsNotExist(errRemove) {
|
||||
c.JSON(404, gin.H{"error": "file not found"})
|
||||
} else {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
|
||||
return filepath.Base(name), http.StatusNotFound, errAuthFileNotFound
|
||||
}
|
||||
return
|
||||
return filepath.Base(name), http.StatusInternalServerError, fmt.Errorf("failed to remove file: %w", errRemove)
|
||||
}
|
||||
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
|
||||
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
|
||||
return
|
||||
return filepath.Base(name), http.StatusInternalServerError, errDeleteRecord
|
||||
}
|
||||
if targetID != "" {
|
||||
h.disableAuth(ctx, targetID)
|
||||
} else {
|
||||
h.disableAuth(ctx, targetPath)
|
||||
}
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
return filepath.Base(name), http.StatusOK, nil
|
||||
}
|
||||
|
||||
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
|
||||
@@ -748,10 +934,25 @@ func (h *Handler) authIDForPath(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
path = filepath.Clean(path)
|
||||
if !filepath.IsAbs(path) {
|
||||
if abs, errAbs := filepath.Abs(path); errAbs == nil {
|
||||
path = abs
|
||||
}
|
||||
}
|
||||
id := path
|
||||
if h != nil && h.cfg != nil {
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if resolvedAuthDir, errResolve := util.ResolveAuthDir(authDir); errResolve == nil && resolvedAuthDir != "" {
|
||||
authDir = resolvedAuthDir
|
||||
}
|
||||
if authDir != "" {
|
||||
authDir = filepath.Clean(authDir)
|
||||
if !filepath.IsAbs(authDir) {
|
||||
if abs, errAbs := filepath.Abs(authDir); errAbs == nil {
|
||||
authDir = abs
|
||||
}
|
||||
}
|
||||
if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
@@ -768,19 +969,27 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
if h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
auth, err := h.buildAuthFromFileData(path, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.upsertAuthRecord(ctx, auth)
|
||||
}
|
||||
|
||||
func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Auth, error) {
|
||||
if path == "" {
|
||||
return fmt.Errorf("auth path is empty")
|
||||
return nil, fmt.Errorf("auth path is empty")
|
||||
}
|
||||
if data == nil {
|
||||
var err error
|
||||
data, err = os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read auth file: %w", err)
|
||||
return nil, fmt.Errorf("failed to read auth file: %w", err)
|
||||
}
|
||||
}
|
||||
metadata := make(map[string]any)
|
||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||
return fmt.Errorf("invalid auth file: %w", err)
|
||||
return nil, fmt.Errorf("invalid auth file: %w", err)
|
||||
}
|
||||
provider, _ := metadata["type"].(string)
|
||||
if provider == "" {
|
||||
@@ -814,13 +1023,25 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
if hasLastRefresh {
|
||||
auth.LastRefreshedAt = lastRefresh
|
||||
}
|
||||
if existing, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
if !hasLastRefresh {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
if h != nil && h.authManager != nil {
|
||||
if existing, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
if !hasLastRefresh {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
}
|
||||
auth.NextRefreshAfter = existing.NextRefreshAfter
|
||||
auth.Runtime = existing.Runtime
|
||||
}
|
||||
auth.NextRefreshAfter = existing.NextRefreshAfter
|
||||
auth.Runtime = existing.Runtime
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (h *Handler) upsertAuthRecord(ctx context.Context, auth *coreauth.Auth) error {
|
||||
if h == nil || h.authManager == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
if existing, ok := h.authManager.GetByID(auth.ID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
_, err := h.authManager.Update(ctx, auth)
|
||||
return err
|
||||
}
|
||||
|
||||
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
|
||||
files := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
|
||||
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
for _, file := range files {
|
||||
part, err := writer.CreateFormFile("file", file.name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create multipart file: %v", err)
|
||||
}
|
||||
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||
t.Fatalf("failed to write multipart content: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("failed to close multipart writer: %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
ctx.Request = req
|
||||
|
||||
h.UploadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
|
||||
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
fullPath := filepath.Join(authDir, file.name)
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
|
||||
}
|
||||
if string(data) != file.content {
|
||||
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
auths := manager.List()
|
||||
if len(auths) != len(files) {
|
||||
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
|
||||
existingName := "alpha.json"
|
||||
existingContent := `{"type":"codex","email":"alpha@example.com"}`
|
||||
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
|
||||
t.Fatalf("failed to seed existing auth file: %v", err)
|
||||
}
|
||||
|
||||
files := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{name: existingName, content: `{"type":"codex"`},
|
||||
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
for _, file := range files {
|
||||
part, err := writer.CreateFormFile("file", file.name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create multipart file: %v", err)
|
||||
}
|
||||
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||
t.Fatalf("failed to write multipart content: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("failed to close multipart writer: %v", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
ctx.Request = req
|
||||
|
||||
h.UploadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusMultiStatus {
|
||||
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(authDir, existingName))
|
||||
if err != nil {
|
||||
t.Fatalf("expected existing auth file to remain readable: %v", err)
|
||||
}
|
||||
if string(data) != existingContent {
|
||||
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
|
||||
}
|
||||
|
||||
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("expected valid auth file to be created: %v", err)
|
||||
}
|
||||
if string(betaData) != files[1].content {
|
||||
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
files := []string{"alpha.json", "beta.json"}
|
||||
for _, name := range files {
|
||||
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
|
||||
t.Fatalf("failed to write auth file %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(
|
||||
http.MethodDelete,
|
||||
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
|
||||
nil,
|
||||
)
|
||||
ctx.Request = req
|
||||
|
||||
h.DeleteAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
|
||||
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
|
||||
}
|
||||
|
||||
for _, name := range files {
|
||||
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package codebuddy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
BaseURL = "https://copilot.tencent.com"
|
||||
DefaultDomain = "www.codebuddy.cn"
|
||||
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
|
||||
|
||||
codeBuddyStatePath = "/v2/plugin/auth/state"
|
||||
codeBuddyTokenPath = "/v2/plugin/auth/token"
|
||||
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
|
||||
pollInterval = 5 * time.Second
|
||||
maxPollDuration = 5 * time.Minute
|
||||
codeLoginPending = 11217
|
||||
codeSuccess = 0
|
||||
)
|
||||
|
||||
type CodeBuddyAuth struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
|
||||
}
|
||||
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
|
||||
}
|
||||
|
||||
// AuthState holds the state and auth URL returned by the auth state API.
|
||||
type AuthState struct {
|
||||
State string
|
||||
AuthURL string
|
||||
}
|
||||
|
||||
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
|
||||
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
|
||||
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
|
||||
body := []byte("{}")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
|
||||
}
|
||||
|
||||
requestID := uuid.NewString()
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("X-Domain", "copilot.tencent.com")
|
||||
req.Header.Set("X-No-Authorization", "true")
|
||||
req.Header.Set("X-No-User-Id", "true")
|
||||
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||
req.Header.Set("X-No-Department-Info", "true")
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("X-Request-ID", requestID)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy auth state: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
State string `json:"state"`
|
||||
AuthURL string `json:"authUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
|
||||
}
|
||||
if result.Code != codeSuccess {
|
||||
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
|
||||
}
|
||||
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
|
||||
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
|
||||
}
|
||||
|
||||
return &AuthState{
|
||||
State: result.Data.State,
|
||||
AuthURL: result.Data.AuthURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type pollResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
RequestID string `json:"requestId"`
|
||||
Data *struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresIn int64 `json:"expiresIn"`
|
||||
TokenType string `json:"tokenType"`
|
||||
Domain string `json:"domain"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// doPollRequest performs a single polling request, safely reading and closing the response body
|
||||
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
|
||||
}
|
||||
a.applyPollHeaders(req)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy poll: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls until the user completes browser authorization and returns auth data.
|
||||
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
}
|
||||
|
||||
body, statusCode, err := a.doPollRequest(ctx, pollURL)
|
||||
if err != nil {
|
||||
log.Debugf("codebuddy poll: request error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
var result pollResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch result.Code {
|
||||
case codeSuccess:
|
||||
if result.Data == nil {
|
||||
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
|
||||
}
|
||||
userID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||
return &CodeBuddyTokenStorage{
|
||||
AccessToken: result.Data.AccessToken,
|
||||
RefreshToken: result.Data.RefreshToken,
|
||||
ExpiresIn: result.Data.ExpiresIn,
|
||||
TokenType: result.Data.TokenType,
|
||||
Domain: result.Data.Domain,
|
||||
UserID: userID,
|
||||
Type: "codebuddy",
|
||||
}, nil
|
||||
case codeLoginPending:
|
||||
// continue polling
|
||||
default:
|
||||
// TODO: when the CodeBuddy API error code for user denial is known,
|
||||
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
|
||||
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
|
||||
}
|
||||
}
|
||||
return nil, ErrPollingTimeout
|
||||
}
|
||||
|
||||
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
|
||||
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
return "", ErrJWTDecodeFailed
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||
}
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||
}
|
||||
if claims.Sub == "" {
|
||||
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
|
||||
}
|
||||
return claims.Sub, nil
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
|
||||
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
|
||||
if domain == "" {
|
||||
domain = DefaultDomain
|
||||
}
|
||||
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
|
||||
body := []byte("{}")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("X-Domain", domain)
|
||||
req.Header.Set("X-Refresh-Token", refreshToken)
|
||||
req.Header.Set("X-Auth-Refresh-Source", "plugin")
|
||||
req.Header.Set("X-Request-ID", requestID)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("X-User-Id", userID)
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy refresh: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresIn int64 `json:"expiresIn"`
|
||||
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
|
||||
TokenType string `json:"tokenType"`
|
||||
Domain string `json:"domain"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
|
||||
}
|
||||
if result.Code != codeSuccess {
|
||||
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
|
||||
}
|
||||
if result.Data == nil {
|
||||
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
|
||||
}
|
||||
|
||||
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||
if newUserID == "" {
|
||||
newUserID = userID
|
||||
}
|
||||
tokenDomain := result.Data.Domain
|
||||
if tokenDomain == "" {
|
||||
tokenDomain = domain
|
||||
}
|
||||
|
||||
return &CodeBuddyTokenStorage{
|
||||
AccessToken: result.Data.AccessToken,
|
||||
RefreshToken: result.Data.RefreshToken,
|
||||
ExpiresIn: result.Data.ExpiresIn,
|
||||
RefreshExpiresIn: result.Data.RefreshExpiresIn,
|
||||
TokenType: result.Data.TokenType,
|
||||
Domain: tokenDomain,
|
||||
UserID: newUserID,
|
||||
Type: "codebuddy",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("X-No-Authorization", "true")
|
||||
req.Header.Set("X-No-User-Id", "true")
|
||||
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||
req.Header.Set("X-No-Department-Info", "true")
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
}
|
||||
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package codebuddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
|
||||
func newTestAuth(serverURL string) *CodeBuddyAuth {
|
||||
return &CodeBuddyAuth{
|
||||
httpClient: http.DefaultClient,
|
||||
baseURL: serverURL,
|
||||
}
|
||||
}
|
||||
|
||||
// fakeJWT builds a minimal JWT with the given sub claim for testing.
|
||||
func fakeJWT(sub string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
|
||||
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
return header + "." + encodedPayload + ".sig"
|
||||
}
|
||||
|
||||
// --- FetchAuthState tests ---
|
||||
|
||||
func TestFetchAuthState_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Path; got != codeBuddyStatePath {
|
||||
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
|
||||
}
|
||||
if got := r.URL.Query().Get("platform"); got != "CLI" {
|
||||
t.Errorf("expected platform=CLI, got %s", got)
|
||||
}
|
||||
if got := r.Header.Get("User-Agent"); got != UserAgent {
|
||||
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"state": "test-state-abc",
|
||||
"authUrl": "https://example.com/login?state=test-state-abc",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
result, err := auth.FetchAuthState(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.State != "test-state-abc" {
|
||||
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
|
||||
}
|
||||
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
|
||||
t.Errorf("unexpected authURL: %s", result.AuthURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAuthState_NonOKStatus(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("internal error"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.FetchAuthState(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-200 status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAuthState_APIErrorCode(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 10001,
|
||||
"msg": "rate limited",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.FetchAuthState(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-zero code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAuthState_MissingData(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"state": "",
|
||||
"authUrl": "",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.FetchAuthState(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty state/authUrl")
|
||||
}
|
||||
}
|
||||
|
||||
// --- RefreshToken tests ---
|
||||
|
||||
func TestRefreshToken_Success(t *testing.T) {
|
||||
newAccessToken := fakeJWT("refreshed-user-456")
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Path; got != codeBuddyRefreshPath {
|
||||
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
|
||||
}
|
||||
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
|
||||
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
|
||||
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
|
||||
}
|
||||
if got := r.Header.Get("X-User-Id"); got != "user-123" {
|
||||
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
|
||||
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"accessToken": newAccessToken,
|
||||
"refreshToken": "new-refresh-token",
|
||||
"expiresIn": 3600,
|
||||
"refreshExpiresIn": 86400,
|
||||
"tokenType": "bearer",
|
||||
"domain": "custom.domain.com",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if storage.AccessToken != newAccessToken {
|
||||
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
|
||||
}
|
||||
if storage.RefreshToken != "new-refresh-token" {
|
||||
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
|
||||
}
|
||||
if storage.UserID != "refreshed-user-456" {
|
||||
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
|
||||
}
|
||||
if storage.ExpiresIn != 3600 {
|
||||
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
|
||||
}
|
||||
if storage.RefreshExpiresIn != 86400 {
|
||||
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
|
||||
}
|
||||
if storage.Domain != "custom.domain.com" {
|
||||
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
|
||||
}
|
||||
if storage.Type != "codebuddy" {
|
||||
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_DefaultDomain(t *testing.T) {
|
||||
var receivedDomain string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedDomain = r.Header.Get("X-Domain")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"accessToken": fakeJWT("user-1"),
|
||||
"refreshToken": "rt",
|
||||
"expiresIn": 3600,
|
||||
"tokenType": "bearer",
|
||||
"domain": DefaultDomain,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if receivedDomain != DefaultDomain {
|
||||
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_Unauthorized(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_Forbidden(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_APIErrorCode(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 40001,
|
||||
"msg": "invalid refresh token",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-zero API code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
|
||||
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
|
||||
// When the response domain is empty, it should fall back to the request domain.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]any{
|
||||
"accessToken": "not-a-valid-jwt",
|
||||
"refreshToken": "new-rt",
|
||||
"expiresIn": 7200,
|
||||
"tokenType": "bearer",
|
||||
"domain": "",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
auth := newTestAuth(srv.URL)
|
||||
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if storage.UserID != "original-uid" {
|
||||
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
|
||||
}
|
||||
if storage.Domain != "original.domain.com" {
|
||||
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
|
||||
}
|
||||
}
|
||||
22
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
22
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package codebuddy_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
)
|
||||
|
||||
func TestDecodeUserID_ValidJWT(t *testing.T) {
|
||||
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
|
||||
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
|
||||
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
|
||||
auth := codebuddy.NewCodeBuddyAuth(nil)
|
||||
userID, err := auth.DecodeUserID(token)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if userID != "test-user-id-123" {
|
||||
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
|
||||
}
|
||||
}
|
||||
|
||||
25
internal/auth/codebuddy/errors.go
Normal file
25
internal/auth/codebuddy/errors.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package codebuddy
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
|
||||
ErrAccessDenied = errors.New("codebuddy: access denied by user")
|
||||
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
|
||||
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
|
||||
)
|
||||
|
||||
func GetUserFriendlyMessage(err error) string {
|
||||
switch {
|
||||
case errors.Is(err, ErrPollingTimeout):
|
||||
return "Authentication timed out. Please try again."
|
||||
case errors.Is(err, ErrAccessDenied):
|
||||
return "Access denied. Please try again and approve the login request."
|
||||
case errors.Is(err, ErrJWTDecodeFailed):
|
||||
return "Failed to decode token. Please try logging in again."
|
||||
case errors.Is(err, ErrTokenFetchFailed):
|
||||
return "Failed to fetch token from server. Please try again."
|
||||
default:
|
||||
return "Authentication failed: " + err.Error()
|
||||
}
|
||||
}
|
||||
65
internal/auth/codebuddy/token.go
Normal file
65
internal/auth/codebuddy/token.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Package codebuddy provides authentication and token management functionality
|
||||
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
|
||||
package codebuddy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
|
||||
// for managing access tokens and user account information.
|
||||
type CodeBuddyTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// ExpiresIn is the number of seconds until the access token expires.
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
// RefreshExpiresIn is the number of seconds until the refresh token expires.
|
||||
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
|
||||
// TokenType is the type of token, typically "bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Domain is the CodeBuddy service domain/region.
|
||||
Domain string `json:"domain"`
|
||||
// UserID is the user ID associated with this token.
|
||||
UserID string `json:"user_id"`
|
||||
// Type indicates the authentication provider type, always "codebuddy" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
s.Type = "codebuddy"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(s); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
)
|
||||
|
||||
// newAuthManager creates a new authentication manager instance with all supported
|
||||
// authenticators and a file-based token store. It initializes authenticators for
|
||||
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
|
||||
// authenticators and a file-based token store.
|
||||
//
|
||||
// Returns:
|
||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||
@@ -24,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
43
internal/cmd/codebuddy_login.go
Normal file
43
internal/cmd/codebuddy_login.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
|
||||
// It initiates the OAuth authentication, displays the user code for the user to enter
|
||||
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy and auth directory settings
|
||||
// - options: Login options including browser behavior settings
|
||||
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("CodeBuddy authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("CodeBuddy authentication successful!")
|
||||
}
|
||||
55
internal/config/claude_header_defaults_test.go
Normal file
55
internal/config/claude_header_defaults_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
configYAML := []byte(`
|
||||
claude-header-defaults:
|
||||
user-agent: " claude-cli/2.1.70 (external, cli) "
|
||||
package-version: " 0.80.0 "
|
||||
runtime-version: " v24.5.0 "
|
||||
os: " MacOS "
|
||||
arch: " arm64 "
|
||||
timeout: " 900 "
|
||||
stabilize-device-profile: false
|
||||
`)
|
||||
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfigOptional(configPath, false)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||
}
|
||||
|
||||
if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
|
||||
t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
|
||||
t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
|
||||
t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
|
||||
t.Fatalf("OS = %q, want %q", got, "MacOS")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
|
||||
t.Fatalf("Arch = %q, want %q", got, "arm64")
|
||||
}
|
||||
if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
|
||||
t.Fatalf("Timeout = %q, want %q", got, "900")
|
||||
}
|
||||
if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||
t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
|
||||
}
|
||||
if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
|
||||
t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -145,13 +146,19 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
|
||||
// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
|
||||
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
|
||||
// profiles are enabled, OS/Arch become the pinned platform baseline, while
|
||||
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
OS string `yaml:"os" json:"os"`
|
||||
Arch string `yaml:"arch" json:"arch"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
|
||||
}
|
||||
|
||||
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||
@@ -568,6 +575,10 @@ type OpenAICompatibilityModel struct {
|
||||
|
||||
// Alias is the model name alias that clients will use to reference this model.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
|
||||
// Thinking configures the thinking/reasoning capability for this model.
|
||||
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
|
||||
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||
@@ -694,6 +705,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize Codex header defaults.
|
||||
cfg.SanitizeCodexHeaderDefaults()
|
||||
|
||||
// Sanitize Claude header defaults.
|
||||
cfg.SanitizeClaudeHeaderDefaults()
|
||||
|
||||
// Sanitize Claude key headers
|
||||
cfg.SanitizeClaudeKeys()
|
||||
|
||||
@@ -796,6 +810,20 @@ func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||
}
|
||||
|
||||
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
|
||||
// configured Claude fingerprint baseline values.
|
||||
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
|
||||
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
|
||||
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
|
||||
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
|
||||
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
|
||||
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
|
||||
@@ -88,6 +88,87 @@ func GetAntigravityModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Antigravity)
|
||||
}
|
||||
|
||||
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
|
||||
// These models are served through the copilot.tencent.com API.
|
||||
func GetCodeBuddyModels() []*ModelInfo {
|
||||
now := int64(1748044800) // 2025-05-24
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "glm-5.0",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "GLM-5.0",
|
||||
Description: "GLM-5.0 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "glm-4.7",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "GLM-4.7",
|
||||
Description: "GLM-4.7 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "minimax-m2.5",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "MiniMax M2.5",
|
||||
Description: "MiniMax M2.5 via CodeBuddy",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2.5",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Kimi K2.5",
|
||||
Description: "Kimi K2.5 via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "deepseek-v3-2-volc",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "hunyuan-2.0-thinking",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "tencent",
|
||||
Type: "codebuddy",
|
||||
DisplayName: "Hunyuan 2.0 Thinking",
|
||||
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
@@ -148,6 +229,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
return GetAntigravityModels()
|
||||
case "codebuddy":
|
||||
return GetCodeBuddyModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -176,6 +259,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
GetCodeBuddyModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
@@ -365,6 +449,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.4",
|
||||
Description: "OpenAI GPT-5.4 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
Object: "model",
|
||||
|
||||
@@ -73,16 +73,16 @@ type availableModelsCacheEntry struct {
|
||||
// Values are interpreted in provider-native token units.
|
||||
type ThinkingSupport struct {
|
||||
// Min is the minimum allowed thinking budget (inclusive).
|
||||
Min int `json:"min,omitempty"`
|
||||
Min int `json:"min,omitempty" yaml:"min,omitempty"`
|
||||
// Max is the maximum allowed thinking budget (inclusive).
|
||||
Max int `json:"max,omitempty"`
|
||||
Max int `json:"max,omitempty" yaml:"max,omitempty"`
|
||||
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
|
||||
ZeroAllowed bool `json:"zero_allowed,omitempty"`
|
||||
ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
|
||||
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
|
||||
DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
|
||||
DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
|
||||
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
|
||||
// When set, the model uses level-based reasoning instead of token budgets.
|
||||
Levels []string `json:"levels,omitempty"`
|
||||
Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
|
||||
}
|
||||
|
||||
// ModelRegistration tracks a model's availability
|
||||
|
||||
383
internal/runtime/executor/claude_device_profile.go
Normal file
383
internal/runtime/executor/claude_device_profile.go
Normal file
@@ -0,0 +1,383 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClaudeFingerprintUserAgent = "claude-cli/2.1.63 (external, cli)"
|
||||
defaultClaudeFingerprintPackageVersion = "0.74.0"
|
||||
defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
|
||||
defaultClaudeFingerprintOS = "MacOS"
|
||||
defaultClaudeFingerprintArch = "arm64"
|
||||
claudeDeviceProfileTTL = 7 * 24 * time.Hour
|
||||
claudeDeviceProfileCleanupPeriod = time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)
|
||||
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu sync.RWMutex
|
||||
claudeDeviceProfileCacheCleanupOnce sync.Once
|
||||
|
||||
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
|
||||
)
|
||||
|
||||
type claudeCLIVersion struct {
|
||||
major int
|
||||
minor int
|
||||
patch int
|
||||
}
|
||||
|
||||
func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
|
||||
switch {
|
||||
case v.major != other.major:
|
||||
if v.major > other.major {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
case v.minor != other.minor:
|
||||
if v.minor > other.minor {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
case v.patch != other.patch:
|
||||
if v.patch > other.patch {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
type claudeDeviceProfile struct {
|
||||
UserAgent string
|
||||
PackageVersion string
|
||||
RuntimeVersion string
|
||||
OS string
|
||||
Arch string
|
||||
Version claudeCLIVersion
|
||||
HasVersion bool
|
||||
}
|
||||
|
||||
type claudeDeviceProfileCacheEntry struct {
|
||||
profile claudeDeviceProfile
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||
return false
|
||||
}
|
||||
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
||||
}
|
||||
|
||||
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if strings.TrimSpace(cfgVal) != "" {
|
||||
return strings.TrimSpace(cfgVal)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
profile := claudeDeviceProfile{
|
||||
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
||||
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
||||
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
||||
OS: hdrDefault(hd.OS, defaultClaudeFingerprintOS),
|
||||
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
||||
}
|
||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||
profile.Version = version
|
||||
profile.HasVersion = true
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
|
||||
matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
|
||||
if len(matches) != 4 {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
major, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
minor, err := strconv.Atoi(matches[2])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
patch, err := strconv.Atoi(matches[3])
|
||||
if err != nil {
|
||||
return claudeCLIVersion{}, false
|
||||
}
|
||||
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
||||
}
|
||||
|
||||
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
|
||||
if candidate.UserAgent == "" || !candidate.HasVersion {
|
||||
return false
|
||||
}
|
||||
if current.UserAgent == "" || !current.HasVersion {
|
||||
return true
|
||||
}
|
||||
return candidate.Version.Compare(current.Version) > 0
|
||||
}
|
||||
|
||||
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||
profile.OS = baseline.OS
|
||||
profile.Arch = baseline.Arch
|
||||
return profile
|
||||
}
|
||||
|
||||
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
||||
// baseline platform and enforces the baseline software fingerprint as a floor.
|
||||
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
||||
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||
profile.UserAgent = baseline.UserAgent
|
||||
profile.PackageVersion = baseline.PackageVersion
|
||||
profile.RuntimeVersion = baseline.RuntimeVersion
|
||||
profile.Version = baseline.Version
|
||||
profile.HasVersion = baseline.HasVersion
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
|
||||
if headers == nil {
|
||||
return claudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
||||
version, ok := parseClaudeCLIVersion(userAgent)
|
||||
if !ok {
|
||||
return claudeDeviceProfile{}, false
|
||||
}
|
||||
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
profile := claudeDeviceProfile{
|
||||
UserAgent: userAgent,
|
||||
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
||||
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
||||
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
||||
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
||||
Version: version,
|
||||
HasVersion: true,
|
||||
}
|
||||
return profile, true
|
||||
}
|
||||
|
||||
func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
|
||||
if headers == nil {
|
||||
return fallback
|
||||
}
|
||||
if value := strings.TrimSpace(headers.Get(name)); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||
switch {
|
||||
case auth != nil && strings.TrimSpace(auth.ID) != "":
|
||||
return "auth:" + strings.TrimSpace(auth.ID)
|
||||
case strings.TrimSpace(apiKey) != "":
|
||||
return "api_key:" + strings.TrimSpace(apiKey)
|
||||
default:
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func startClaudeDeviceProfileCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredClaudeDeviceProfiles()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredClaudeDeviceProfiles() {
|
||||
now := time.Now()
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
for key, entry := range claudeDeviceProfileCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(claudeDeviceProfileCache, key)
|
||||
}
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
|
||||
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
||||
|
||||
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
||||
now := time.Now()
|
||||
baseline := defaultClaudeDeviceProfile(cfg)
|
||||
candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
|
||||
if hasCandidate {
|
||||
candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
|
||||
}
|
||||
if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
|
||||
hasCandidate = false
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.RLock()
|
||||
entry, hasCached := claudeDeviceProfileCache[cacheKey]
|
||||
cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||
claudeDeviceProfileCacheMu.RUnlock()
|
||||
|
||||
if hasCandidate {
|
||||
if claudeDeviceProfileBeforeCandidateStore != nil {
|
||||
claudeDeviceProfileBeforeCandidateStore(candidate)
|
||||
}
|
||||
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
entry, hasCached = claudeDeviceProfileCache[cacheKey]
|
||||
cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||
if cachedValid {
|
||||
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||
}
|
||||
if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
|
||||
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||
claudeDeviceProfileCache[cacheKey] = entry
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return entry.profile
|
||||
}
|
||||
|
||||
claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
|
||||
profile: candidate,
|
||||
expire: now.Add(claudeDeviceProfileTTL),
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return candidate
|
||||
}
|
||||
|
||||
if cachedValid {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
entry = claudeDeviceProfileCache[cacheKey]
|
||||
if entry.expire.After(now) && entry.profile.UserAgent != "" {
|
||||
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||
claudeDeviceProfileCache[cacheKey] = entry
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
return entry.profile
|
||||
}
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
return baseline
|
||||
}
|
||||
|
||||
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
for _, headerName := range []string{
|
||||
"User-Agent",
|
||||
"X-Stainless-Package-Version",
|
||||
"X-Stainless-Runtime-Version",
|
||||
"X-Stainless-Os",
|
||||
"X-Stainless-Arch",
|
||||
} {
|
||||
r.Header.Del(headerName)
|
||||
}
|
||||
r.Header.Set("User-Agent", profile.UserAgent)
|
||||
r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||
r.Header.Set("X-Stainless-Os", profile.OS)
|
||||
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||
}
|
||||
|
||||
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
profile := defaultClaudeDeviceProfile(cfg)
|
||||
miscEnsure := func(name, fallback string) {
|
||||
if strings.TrimSpace(r.Header.Get(name)) != "" {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ginHeaders.Get(name)) != "" {
|
||||
r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
|
||||
return
|
||||
}
|
||||
r.Header.Set(name, fallback)
|
||||
}
|
||||
|
||||
miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||
miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
|
||||
miscEnsure("X-Stainless-Os", mapStainlessOS())
|
||||
miscEnsure("X-Stainless-Arch", mapStainlessArch())
|
||||
|
||||
// Legacy mode preserves per-auth custom header overrides. By the time we get
|
||||
// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
|
||||
if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
return
|
||||
}
|
||||
r.Header.Set("User-Agent", profile.UserAgent)
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -767,36 +766,6 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if cfgVal != "" {
|
||||
@@ -824,6 +793,11 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
|
||||
var deviceProfile claudeDeviceProfile
|
||||
if stabilizeDeviceProfile {
|
||||
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
|
||||
}
|
||||
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
@@ -867,25 +841,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
// For User-Agent, only forward the client's header if it's already a Claude Code client.
|
||||
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
|
||||
// to avoid leaking the real client identity during cloaking.
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = ginHeaders.Get("User-Agent")
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
} else {
|
||||
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
@@ -897,13 +855,19 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
r.Header.Set("Accept", "application/json")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
|
||||
// to the configured baseline while still allowing newer official
|
||||
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
if stabilizeDeviceProfile {
|
||||
applyClaudeDeviceProfileHeaders(r, deviceProfile)
|
||||
} else {
|
||||
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
|
||||
}
|
||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||
|
||||
@@ -8,8 +8,11 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -19,6 +22,587 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func resetClaudeDeviceProfileCache() {
|
||||
claudeDeviceProfileCacheMu.Lock()
|
||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||
claudeDeviceProfileCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginReq := httptest.NewRequest(http.MethodPost, "http://localhost/v1/messages", nil)
|
||||
ginReq.Header = incoming.Clone()
|
||||
ginCtx.Request = ginReq
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||
return req.WithContext(context.WithValue(req.Context(), "gin", ginCtx))
|
||||
}
|
||||
|
||||
func assertClaudeFingerprint(t *testing.T, headers http.Header, userAgent, pkgVersion, runtimeVersion, osName, arch string) {
|
||||
t.Helper()
|
||||
|
||||
if got := headers.Get("User-Agent"); got != userAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", got, userAgent)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Package-Version"); got != pkgVersion {
|
||||
t.Fatalf("X-Stainless-Package-Version = %q, want %q", got, pkgVersion)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Runtime-Version"); got != runtimeVersion {
|
||||
t.Fatalf("X-Stainless-Runtime-Version = %q, want %q", got, runtimeVersion)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Os"); got != osName {
|
||||
t.Fatalf("X-Stainless-Os = %q, want %q", got, osName)
|
||||
}
|
||||
if got := headers.Get("X-Stainless-Arch"); got != arch {
|
||||
t.Fatalf("X-Stainless-Arch = %q, want %q", got, arch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
Timeout: "900",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseline",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-baseline",
|
||||
"header:User-Agent": "evil-client/9.9",
|
||||
"header:X-Stainless-Os": "Linux",
|
||||
"header:X-Stainless-Arch": "x64",
|
||||
"header:X-Stainless-Package-Version": "9.9.9",
|
||||
},
|
||||
}
|
||||
incoming := http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, incoming)
|
||||
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
||||
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_TracksHighestClaudeCLIFingerprint(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-upgrade",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
firstReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(firstReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"lobe-chat/1.0"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
|
||||
|
||||
higherReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||
"X-Stainless-Os": []string{"MacOS"},
|
||||
"X-Stainless-Arch": []string{"arm64"},
|
||||
})
|
||||
applyClaudeHeaders(higherReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, higherReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
|
||||
|
||||
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.73.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(lowerReq, auth, "key-upgrade", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_DoesNotDowngradeConfiguredBaselineOnFirstClaudeClient(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseline-floor",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-baseline-floor",
|
||||
},
|
||||
}
|
||||
|
||||
olderClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(olderClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, olderClaudeReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
|
||||
newerClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.81.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(newerClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, newerClaudeReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UpgradesCachedSoftwareFingerprintWhenBaselineAdvances(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
oldCfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.77 (external, cli)",
|
||||
PackageVersion: "0.87.0",
|
||||
RuntimeVersion: "v24.8.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-baseline-reload",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-baseline-reload",
|
||||
},
|
||||
}
|
||||
|
||||
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.81.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(officialReq, auth, "key-baseline-reload", false, nil, oldCfg)
|
||||
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-baseline-reload", false, nil, newCfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_LearnsOfficialFingerprintAfterCustomBaselineFallback(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "my-gateway/1.0",
|
||||
PackageVersion: "custom-pkg",
|
||||
RuntimeVersion: "custom-runtime",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-custom-baseline-learning",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-custom-baseline-learning",
|
||||
},
|
||||
}
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "my-gateway/1.0", "custom-pkg", "custom-runtime", "MacOS", "arm64")
|
||||
|
||||
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.87.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(officialReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
|
||||
postLearningThirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(postLearningThirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, postLearningThirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-racy-upgrade",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-racy-upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
lowPaused := make(chan struct{})
|
||||
releaseLow := make(chan struct{})
|
||||
var pauseOnce sync.Once
|
||||
var releaseOnce sync.Once
|
||||
|
||||
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
|
||||
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
|
||||
return
|
||||
}
|
||||
pauseOnce.Do(func() { close(lowPaused) })
|
||||
<-releaseLow
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
claudeDeviceProfileBeforeCandidateStore = nil
|
||||
releaseOnce.Do(func() { close(releaseLow) })
|
||||
})
|
||||
|
||||
lowResultCh := make(chan claudeDeviceProfile, 1)
|
||||
go func() {
|
||||
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
}, cfg)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lowPaused:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for lower candidate to pause before storing")
|
||||
}
|
||||
|
||||
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||
"X-Stainless-Os": []string{"MacOS"},
|
||||
"X-Stainless-Arch": []string{"arm64"},
|
||||
}, cfg)
|
||||
releaseOnce.Do(func() { close(releaseLow) })
|
||||
|
||||
select {
|
||||
case lowResult := <-lowResultCh:
|
||||
if lowResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
t.Fatalf("lowResult.UserAgent = %q, want %q", lowResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||
}
|
||||
if lowResult.PackageVersion != "0.75.0" {
|
||||
t.Fatalf("lowResult.PackageVersion = %q, want %q", lowResult.PackageVersion, "0.75.0")
|
||||
}
|
||||
if lowResult.OS != "MacOS" || lowResult.Arch != "arm64" {
|
||||
t.Fatalf("lowResult platform = %s/%s, want %s/%s", lowResult.OS, lowResult.Arch, "MacOS", "arm64")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for lower candidate result")
|
||||
}
|
||||
|
||||
if highResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
t.Fatalf("highResult.UserAgent = %q, want %q", highResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||
}
|
||||
if highResult.OS != "MacOS" || highResult.Arch != "arm64" {
|
||||
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
|
||||
}
|
||||
|
||||
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
}, cfg)
|
||||
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||
t.Fatalf("cached.UserAgent = %q, want %q", cached.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||
}
|
||||
if cached.PackageVersion != "0.75.0" {
|
||||
t.Fatalf("cached.PackageVersion = %q, want %q", cached.PackageVersion, "0.75.0")
|
||||
}
|
||||
if cached.OS != "MacOS" || cached.Arch != "arm64" {
|
||||
t.Fatalf("cached platform = %s/%s, want %s/%s", cached.OS, cached.Arch, "MacOS", "arm64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_ThirdPartyBaselineThenOfficialUpgradeKeepsPinnedPlatform(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
stabilize := true
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||
PackageVersion: "0.80.0",
|
||||
RuntimeVersion: "v24.5.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-third-party-then-official",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-third-party-then-official",
|
||||
},
|
||||
}
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-third-party-then-official", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||
|
||||
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.87.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(officialReq, auth, "key-third-party-then-official", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_DisableDeviceProfileStabilization(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
stabilize := false
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-disable-stability",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-disable-stability",
|
||||
},
|
||||
}
|
||||
|
||||
firstReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(firstReq, auth, "key-disable-stability", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "Linux", "x64")
|
||||
|
||||
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"lobe-chat/1.0"},
|
||||
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(thirdPartyReq, auth, "key-disable-stability", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.60 (external, cli)", "0.10.0", "v18.0.0", "Windows", "x64")
|
||||
|
||||
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.73.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
|
||||
"X-Stainless-Os": []string{"Windows"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(lowerReq, auth, "key-disable-stability", false, nil, cfg)
|
||||
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.61 (external, cli)", "0.73.0", "v24.2.0", "Windows", "x64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_LegacyModePreservesConfiguredUserAgentOverrideForClaudeClients(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
stabilize := false
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-legacy-ua-override",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-legacy-ua-override",
|
||||
"header:User-Agent": "config-ua/1.0",
|
||||
},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||
"X-Stainless-Os": []string{"Linux"},
|
||||
"X-Stainless-Arch": []string{"x64"},
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-legacy-ua-override", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "config-ua/1.0", "0.74.0", "v24.3.0", "Linux", "x64")
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
stabilize := false
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
StabilizeDeviceProfile: &stabilize,
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-legacy-runtime-os-arch",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-legacy-runtime-os-arch",
|
||||
},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||
}
|
||||
|
||||
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
|
||||
resetClaudeDeviceProfileCache()
|
||||
|
||||
cfg := &config.Config{
|
||||
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||
PackageVersion: "0.70.0",
|
||||
RuntimeVersion: "v22.0.0",
|
||||
OS: "MacOS",
|
||||
Arch: "arm64",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "auth-unset-runtime-os-arch",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "key-unset-runtime-os-arch",
|
||||
},
|
||||
}
|
||||
|
||||
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||
"User-Agent": []string{"curl/8.7.1"},
|
||||
})
|
||||
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
|
||||
|
||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||
}
|
||||
|
||||
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
|
||||
if claudeDeviceProfileStabilizationEnabled(nil) {
|
||||
t.Fatal("expected nil config to default to disabled stabilization")
|
||||
}
|
||||
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
||||
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
343
internal/runtime/executor/codebuddy_executor.go
Normal file
343
internal/runtime/executor/codebuddy_executor.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codeBuddyChatPath = "/v2/chat/completions"
|
||||
codeBuddyAuthType = "codebuddy"
|
||||
)
|
||||
|
||||
// CodeBuddyExecutor handles requests to the CodeBuddy API.
|
||||
type CodeBuddyExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
|
||||
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
|
||||
return &CodeBuddyExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
|
||||
|
||||
// codeBuddyCredentials extracts the access token and domain from auth metadata.
|
||||
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
|
||||
if auth == nil {
|
||||
return "", "", ""
|
||||
}
|
||||
accessToken = metaStringValue(auth.Metadata, "access_token")
|
||||
userID = metaStringValue(auth.Metadata, "user_id")
|
||||
domain = metaStringValue(auth.Metadata, "domain")
|
||||
if domain == "" {
|
||||
domain = codebuddy.DefaultDomain
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
e.applyHeaders(req, accessToken, userID, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("codebuddy executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := codebuddy.BaseURL + codeBuddyChatPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := codebuddy.BaseURL + codeBuddyChatPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy executor: close stream body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, maxScannerBufferSize)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh exchanges the CodeBuddy refresh token for a new access token.
|
||||
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("codebuddy: missing auth")
|
||||
}
|
||||
|
||||
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
|
||||
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
|
||||
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
updated := auth.Clone()
|
||||
updated.Metadata["access_token"] = storage.AccessToken
|
||||
if storage.RefreshToken != "" {
|
||||
updated.Metadata["refresh_token"] = storage.RefreshToken
|
||||
}
|
||||
updated.Metadata["expires_in"] = storage.ExpiresIn
|
||||
updated.Metadata["domain"] = storage.Domain
|
||||
if storage.UserID != "" {
|
||||
updated.Metadata["user_id"] = storage.UserID
|
||||
}
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens is not supported for CodeBuddy.
|
||||
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
|
||||
}
|
||||
|
||||
// applyHeaders sets required headers for CodeBuddy API requests.
|
||||
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", codebuddy.UserAgent)
|
||||
req.Header.Set("X-User-Id", userID)
|
||||
req.Header.Set("X-Domain", domain)
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("X-IDE-Type", "CLI")
|
||||
req.Header.Set("X-IDE-Name", "CLI")
|
||||
req.Header.Set("X-IDE-Version", "2.63.2")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
}
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexOriginator = "codex_cli_rs"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
@@ -645,8 +645,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
@@ -663,8 +665,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
isAPIKey = true
|
||||
}
|
||||
}
|
||||
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||
r.Header.Set("Originator", originator)
|
||||
} else if !isAPIKey {
|
||||
r.Header.Set("Originator", codexOriginator)
|
||||
}
|
||||
if !isAPIKey {
|
||||
r.Header.Set("Originator", "codex_cli_rs")
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||
r.Header.Set("Chatgpt-Account-Id", accountID)
|
||||
|
||||
@@ -814,9 +814,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "Version", "")
|
||||
|
||||
misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
|
||||
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
||||
if betaHeader == "" && ginHeaders != nil {
|
||||
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
|
||||
@@ -834,8 +835,12 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
isAPIKey = true
|
||||
}
|
||||
}
|
||||
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||
headers.Set("Originator", originator)
|
||||
} else if !isAPIKey {
|
||||
headers.Set("Originator", codexOriginator)
|
||||
}
|
||||
if !isAPIKey {
|
||||
headers.Set("Originator", "codex_cli_rs")
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
|
||||
|
||||
@@ -41,9 +41,46 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("Version"); got != "" {
|
||||
t.Fatalf("Version = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("X-Codex-Turn-Metadata"); got != "" {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("X-Client-Request-Id"); got != "" {
|
||||
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"Originator": "Codex Desktop",
|
||||
"Version": "0.115.0-alpha.27",
|
||||
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||
})
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
|
||||
|
||||
if got := headers.Get("Originator"); got != "Codex Desktop" {
|
||||
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||
}
|
||||
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
|
||||
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||
}
|
||||
if got := headers.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
|
||||
}
|
||||
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
@@ -177,6 +214,57 @@ func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
req = req.WithContext(contextWithGinHeaders(map[string]string{
|
||||
"Originator": "Codex Desktop",
|
||||
"Version": "0.115.0-alpha.27",
|
||||
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||
}))
|
||||
|
||||
applyCodexHeaders(req, auth, "oauth-token", true, nil)
|
||||
|
||||
if got := req.Header.Get("Originator"); got != "Codex Desktop" {
|
||||
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||
}
|
||||
if got := req.Header.Get("Version"); got != "0.115.0-alpha.27" {
|
||||
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||
}
|
||||
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
|
||||
}
|
||||
if got := req.Header.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersDoesNotInjectClientOnlyHeadersByDefault(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
|
||||
applyCodexHeaders(req, nil, "oauth-token", true, nil)
|
||||
|
||||
if got := req.Header.Get("Version"); got != "" {
|
||||
t.Fatalf("Version = %q, want empty", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != "" {
|
||||
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Client-Request-Id"); got != "" {
|
||||
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func contextWithGinHeaders(headers map[string]string) context.Context {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -577,9 +577,33 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model
|
||||
return true
|
||||
}
|
||||
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(baseModel, githubCopilotAuthType); info != nil {
|
||||
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
|
||||
}
|
||||
if info := lookupGitHubCopilotStaticModelInfo(baseModel); info != nil {
|
||||
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
|
||||
}
|
||||
return strings.Contains(baseModel, "codex")
|
||||
}
|
||||
|
||||
func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
|
||||
for _, info := range registry.GetStaticModelDefinitionsByChannel(githubCopilotAuthType) {
|
||||
if info != nil && strings.EqualFold(info.ID, model) {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsEndpoint(endpoints []string, endpoint string) bool {
|
||||
for _, item := range endpoints {
|
||||
if item == endpoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
// to a joined string. GitHub Copilot requires assistant content as a string;
|
||||
// sending it as an array causes Claude models to re-answer all previous prompts.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -70,6 +71,29 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "github-copilot-test-client"
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
|
||||
ID: "gpt-5.4",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
}})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
|
||||
@@ -73,17 +73,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
usage.PublishRecord(ctx, usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Failed: failed,
|
||||
Detail: detail,
|
||||
})
|
||||
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -96,20 +86,39 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
usage.PublishRecord(ctx, usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Failed: false,
|
||||
Detail: usage.Detail{},
|
||||
})
|
||||
usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||
if r == nil {
|
||||
return usage.Record{Detail: detail, Failed: failed}
|
||||
}
|
||||
return usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: r.model,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
RequestedAt: r.requestedAt,
|
||||
Latency: r.latency(),
|
||||
Failed: failed,
|
||||
Detail: detail,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageReporter) latency() time.Duration {
|
||||
if r == nil || r.requestedAt.IsZero() {
|
||||
return 0
|
||||
}
|
||||
latency := time.Since(r.requestedAt)
|
||||
if latency < 0 {
|
||||
return 0
|
||||
}
|
||||
return latency
|
||||
}
|
||||
|
||||
func apiKeyFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package executor
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||
@@ -41,3 +46,19 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||
reporter := &usageReporter{
|
||||
provider: "openai",
|
||||
model: "gpt-5.4",
|
||||
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
||||
}
|
||||
|
||||
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
|
||||
if record.Latency < time.Second {
|
||||
t.Fatalf("latency = %v, want >= 1s", record.Latency)
|
||||
}
|
||||
if record.Latency > 3*time.Second {
|
||||
t.Fatalf("latency = %v, want <= 3s", record.Latency)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,11 +148,11 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
// Always include "text" field — Google Antigravity API requires it
|
||||
// even for redacted thinking where the text is empty.
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||
if thinkingText != "" {
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
}
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
|
||||
@@ -233,7 +233,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
functionResponseJSON := []byte(`{}`)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", funcName)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
@@ -398,6 +398,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
for toolKey := range gjson.ParseBytes(tool).Map() {
|
||||
if util.InArray(allowedToolKeys, toolKey) {
|
||||
continue
|
||||
@@ -471,7 +472,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
case "tool":
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,10 @@ type Params struct {
|
||||
|
||||
// Signature caching support
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
|
||||
// Reverse map: sanitized Gemini function name → original Claude tool name.
|
||||
// Populated lazily on the first response chunk from the original request JSON.
|
||||
ToolNameMap map[string]string
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -71,6 +75,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
@@ -212,7 +217,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude Code API compatibility
|
||||
params.HasToolUse = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
@@ -348,7 +353,7 @@ func resolveStopReason(params *Params) string {
|
||||
// Returns:
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
_ = originalRequestRawJSON
|
||||
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -450,7 +455,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
|
||||
toolIDCounter++
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
|
||||
@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
@@ -309,7 +309,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -384,7 +384,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
fnRawBytes := []byte(fnRaw)
|
||||
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw, _ = sjson.Delete(string(fnRawBytes), "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
@@ -26,6 +27,7 @@ type convertCliResponseToOpenAIChatParams struct {
|
||||
FunctionIndex int
|
||||
SawToolCall bool // Tracks if any tool call was seen in the entire stream
|
||||
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -48,10 +50,14 @@ var functionCallIDCounter uint64
|
||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return [][]byte{}
|
||||
@@ -159,7 +165,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
}
|
||||
|
||||
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
|
||||
@@ -36,6 +36,18 @@ type ToolCallAccumulator struct {
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
completionTokens = usage.Get("output_tokens").Int()
|
||||
cachedTokens = usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
|
||||
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
|
||||
totalTokens = promptTokens + completionTokens
|
||||
|
||||
return promptTokens, completionTokens, totalTokens, cachedTokens
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format.
|
||||
// This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses.
|
||||
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
|
||||
@@ -207,14 +219,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
|
||||
// Handle usage information for token counts
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
return [][]byte{template}
|
||||
|
||||
@@ -366,14 +375,11 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertClaudeResponseToOpenAI(
|
||||
ctx,
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
nil,
|
||||
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":13,"output_tokens":4,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}`),
|
||||
¶m,
|
||||
)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
|
||||
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
|
||||
|
||||
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
|
||||
|
||||
@@ -45,3 +45,48 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToOpenAI_ToolCallChunkOmitsNullContentFields(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
|
||||
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
|
||||
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls").Exists() {
|
||||
t.Fatalf("expected tool_calls to exist, got %s", string(out[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToOpenAI_ToolCallArgumentsDeltaOmitsNullContentFields(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected tool call announcement chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.function_call_arguments.delta","delta":"{\"query\":\"OpenAI\"}"}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
|
||||
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
|
||||
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
|
||||
}
|
||||
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls.0.function.arguments").Exists() {
|
||||
t.Fatalf("expected tool call arguments delta to exist, got %s", string(out[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package responses
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -39,6 +40,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
rawJSON = normalizeCodexBuiltinTools(rawJSON)
|
||||
|
||||
return rawJSON
|
||||
}
|
||||
@@ -82,3 +84,59 @@ func convertSystemRoleToDeveloper(rawJSON []byte) []byte {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// normalizeCodexBuiltinTools rewrites legacy/preview built-in tool variants to the
|
||||
// stable names expected by the current Codex upstream.
|
||||
func normalizeCodexBuiltinTools(rawJSON []byte) []byte {
|
||||
result := rawJSON
|
||||
|
||||
tools := gjson.GetBytes(result, "tools")
|
||||
if tools.IsArray() {
|
||||
toolArray := tools.Array()
|
||||
for i := 0; i < len(toolArray); i++ {
|
||||
typePath := fmt.Sprintf("tools.%d.type", i)
|
||||
result = normalizeCodexBuiltinToolAtPath(result, typePath)
|
||||
}
|
||||
}
|
||||
|
||||
result = normalizeCodexBuiltinToolAtPath(result, "tool_choice.type")
|
||||
|
||||
toolChoiceTools := gjson.GetBytes(result, "tool_choice.tools")
|
||||
if toolChoiceTools.IsArray() {
|
||||
toolArray := toolChoiceTools.Array()
|
||||
for i := 0; i < len(toolArray); i++ {
|
||||
typePath := fmt.Sprintf("tool_choice.tools.%d.type", i)
|
||||
result = normalizeCodexBuiltinToolAtPath(result, typePath)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeCodexBuiltinToolAtPath(rawJSON []byte, path string) []byte {
|
||||
currentType := gjson.GetBytes(rawJSON, path).String()
|
||||
normalizedType := normalizeCodexBuiltinToolType(currentType)
|
||||
if normalizedType == "" {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
updated, err := sjson.SetBytes(rawJSON, path, normalizedType)
|
||||
if err != nil {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
log.Debugf("codex responses: normalized builtin tool type at %s from %q to %q", path, currentType, normalizedType)
|
||||
return updated
|
||||
}
|
||||
|
||||
// normalizeCodexBuiltinToolType centralizes the current known Codex Responses
|
||||
// built-in tool alias compatibility. If Codex introduces more legacy aliases,
|
||||
// extend this helper instead of adding path-specific rewrite logic elsewhere.
|
||||
func normalizeCodexBuiltinToolType(toolType string) string {
|
||||
switch toolType {
|
||||
case "web_search_preview", "web_search_preview_2025_03_11":
|
||||
return "web_search"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,6 +264,52 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIResponsesRequestToCodex_NormalizesWebSearchPreview(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.4-mini",
|
||||
"input": "find latest OpenAI model news",
|
||||
"tools": [
|
||||
{"type": "web_search_preview_2025_03_11"}
|
||||
],
|
||||
"tool_choice": {
|
||||
"type": "allowed_tools",
|
||||
"tools": [
|
||||
{"type": "web_search_preview"},
|
||||
{"type": "web_search_preview_2025_03_11"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "allowed_tools" {
|
||||
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "allowed_tools", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.tools.0.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.tools.1.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.tools.1.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIResponsesRequestToCodex_NormalizesTopLevelToolChoicePreviewAlias(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.4-mini",
|
||||
"input": "find latest OpenAI model news",
|
||||
"tool_choice": {"type": "web_search_preview_2025_03_11"}
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
|
||||
@@ -89,7 +89,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
case "tool_use":
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
@@ -112,7 +112,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", util.SanitizeFunctionName(funcName))
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
@@ -151,6 +151,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
tool, _ = sjson.DeleteBytes(tool, "strict")
|
||||
tool, _ = sjson.DeleteBytes(tool, "input_examples")
|
||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||
@@ -194,7 +195,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
case "tool":
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,9 @@ type Params struct {
|
||||
ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function
|
||||
ResponseIndex int // Index counter for content blocks in the streaming response
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
|
||||
// Reverse map: sanitized Gemini function name → original Claude tool name.
|
||||
ToolNameMap map[string]string
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -55,6 +58,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +169,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude Code API compatibility
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*Params).ToolNameMap, functionCallResult.Get("name").String())
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
@@ -248,7 +252,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Returns:
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
_ = originalRequestRawJSON
|
||||
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -306,7 +310,7 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
|
||||
toolIDCounter++
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
|
||||
@@ -251,7 +251,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
@@ -268,7 +268,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -331,6 +331,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
continue
|
||||
}
|
||||
}
|
||||
fnRaw, _ = sjson.SetBytes(fnRaw, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw, _ = sjson.DeleteBytes(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -21,8 +22,9 @@ import (
|
||||
|
||||
// convertCliResponseToOpenAIChatParams holds parameters for response conversion.
|
||||
type convertCliResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -45,10 +47,14 @@ var functionCallIDCounter uint64
|
||||
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return [][]byte{}
|
||||
@@ -163,7 +169,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
|
||||
functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -90,6 +91,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
functionName = derived
|
||||
}
|
||||
}
|
||||
functionName = util.SanitizeFunctionName(functionName)
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
@@ -109,6 +111,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
funcName = util.SanitizeFunctionName(funcName)
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
|
||||
@@ -165,6 +168,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[{"functionDeclarations":[]}]`))
|
||||
@@ -202,7 +206,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
case "tool":
|
||||
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type Params struct {
|
||||
ResponseIndex int
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
ToolNameMap map[string]string
|
||||
SanitizedNameMap map[string]string
|
||||
SawToolCall bool
|
||||
}
|
||||
|
||||
@@ -57,6 +58,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
SawToolCall: false,
|
||||
}
|
||||
}
|
||||
@@ -167,6 +169,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
// This processes tool usage requests and formats them for Claude API compatibility
|
||||
(*param).(*Params).SawToolCall = true
|
||||
upstreamToolName := functionCallResult.Get("name").String()
|
||||
upstreamToolName = util.RestoreSanitizedToolName((*param).(*Params).SanitizedNameMap, upstreamToolName)
|
||||
clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName)
|
||||
|
||||
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||
@@ -260,6 +263,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
sanitizedNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
|
||||
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
out, _ = sjson.SetBytes(out, "id", root.Get("responseId").String())
|
||||
@@ -315,6 +319,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
hasToolCall = true
|
||||
|
||||
upstreamToolName := functionCall.Get("name").String()
|
||||
upstreamToolName = util.RestoreSanitizedToolName(sanitizedNameMap, upstreamToolName)
|
||||
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
||||
toolIDCounter++
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
|
||||
@@ -257,7 +257,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
@@ -274,7 +274,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -341,6 +341,9 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
}
|
||||
fnRawBytes := []byte(fnRaw)
|
||||
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw = string(fnRawBytes)
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -22,7 +23,8 @@ import (
|
||||
type convertGeminiResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
// FunctionIndex tracks tool call indices per candidate index to support multiple candidates.
|
||||
FunctionIndex map[int]int
|
||||
FunctionIndex map[int]int
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -46,8 +48,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
// Initialize parameters if nil.
|
||||
if *param == nil {
|
||||
*param = &convertGeminiResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +59,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if p.FunctionIndex == nil {
|
||||
p.FunctionIndex = make(map[int]int)
|
||||
}
|
||||
if p.SanitizedNameMap == nil {
|
||||
p.SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
@@ -191,7 +197,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
|
||||
functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(p.SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
@@ -265,6 +271,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
// Returns:
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
sanitizedNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
var unixTimestamp int64
|
||||
// Initialize template with an empty choices array to support multiple candidates.
|
||||
template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[]}`)
|
||||
@@ -358,7 +365,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
choiceTemplate, _ = sjson.SetRawBytes(choiceTemplate, "message.tool_calls", []byte(`[]`))
|
||||
}
|
||||
functionCallItemTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(sanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -291,7 +292,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
|
||||
case "function_call":
|
||||
// Handle function calls - convert to model message with functionCall
|
||||
name := item.Get("name").String()
|
||||
name := util.SanitizeFunctionName(item.Get("name").String())
|
||||
arguments := item.Get("arguments").String()
|
||||
|
||||
modelContent := []byte(`{"role":"model","parts":[]}`)
|
||||
@@ -333,6 +334,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
return true
|
||||
})
|
||||
}
|
||||
functionName = util.SanitizeFunctionName(functionName)
|
||||
|
||||
functionResponse, _ = sjson.SetBytes(functionResponse, "functionResponse.name", functionName)
|
||||
functionResponse, _ = sjson.SetBytes(functionResponse, "functionResponse.id", callID)
|
||||
@@ -375,7 +377,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
funcDecl := []byte(`{"name":"","description":"","parametersJsonSchema":{}}`)
|
||||
|
||||
if name := tool.Get("name"); name.Exists() {
|
||||
funcDecl, _ = sjson.SetBytes(funcDecl, "name", name.String())
|
||||
funcDecl, _ = sjson.SetBytes(funcDecl, "name", util.SanitizeFunctionName(name.String()))
|
||||
}
|
||||
if desc := tool.Get("description"); desc.Exists() {
|
||||
funcDecl, _ = sjson.SetBytes(funcDecl, "description", desc.String())
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -36,11 +37,12 @@ type geminiToResponsesState struct {
|
||||
ReasoningClosed bool
|
||||
|
||||
// function call aggregation (keyed by output_index)
|
||||
NextIndex int
|
||||
FuncArgsBuf map[int]*strings.Builder
|
||||
FuncNames map[int]string
|
||||
FuncCallIDs map[int]string
|
||||
FuncDone map[int]bool
|
||||
NextIndex int
|
||||
FuncArgsBuf map[int]*strings.Builder
|
||||
FuncNames map[int]string
|
||||
FuncCallIDs map[int]string
|
||||
FuncDone map[int]bool
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||
@@ -90,10 +92,11 @@ func emitEvent(event string, payload []byte) []byte {
|
||||
func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &geminiToResponsesState{
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncDone: make(map[int]bool),
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncDone: make(map[int]bool),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
st := (*param).(*geminiToResponsesState)
|
||||
@@ -109,6 +112,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
if st.FuncDone == nil {
|
||||
st.FuncDone = make(map[int]bool)
|
||||
}
|
||||
if st.SanitizedNameMap == nil {
|
||||
st.SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
@@ -306,7 +312,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// Responses streaming requires message done events before the next output_item.added.
|
||||
finalizeReasoning()
|
||||
finalizeMessage()
|
||||
name := fc.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(st.SanitizedNameMap, fc.Get("name").String())
|
||||
idx := st.NextIndex
|
||||
st.NextIndex++
|
||||
// Ensure buffers
|
||||
@@ -565,6 +571,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
root = unwrapGeminiResponseRoot(root)
|
||||
sanitizedNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
|
||||
// Base response scaffold
|
||||
resp := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`)
|
||||
@@ -694,7 +701,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
return true
|
||||
}
|
||||
if fc := p.Get("functionCall"); fc.Exists() {
|
||||
name := fc.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(sanitizedNameMap, fc.Get("name").String())
|
||||
args := fc.Get("args")
|
||||
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
|
||||
itemJSON := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
|
||||
@@ -87,9 +87,10 @@ type modelStats struct {
|
||||
Details []RequestDetail
|
||||
}
|
||||
|
||||
// RequestDetail stores the timestamp and token usage for a single request.
|
||||
// RequestDetail stores the timestamp, latency, and token usage for a single request.
|
||||
type RequestDetail struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
LatencyMs int64 `json:"latency_ms"`
|
||||
Source string `json:"source"`
|
||||
AuthIndex string `json:"auth_index"`
|
||||
Tokens TokenStats `json:"tokens"`
|
||||
@@ -198,6 +199,7 @@ func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record)
|
||||
}
|
||||
s.updateAPIStats(stats, modelName, RequestDetail{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: normaliseLatency(record.Latency),
|
||||
Source: record.Source,
|
||||
AuthIndex: record.AuthIndex,
|
||||
Tokens: detail,
|
||||
@@ -332,6 +334,9 @@ func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResu
|
||||
}
|
||||
for _, detail := range modelSnapshot.Details {
|
||||
detail.Tokens = normaliseTokenStats(detail.Tokens)
|
||||
if detail.LatencyMs < 0 {
|
||||
detail.LatencyMs = 0
|
||||
}
|
||||
if detail.Timestamp.IsZero() {
|
||||
detail.Timestamp = time.Now()
|
||||
}
|
||||
@@ -463,6 +468,13 @@ func normaliseTokenStats(tokens TokenStats) TokenStats {
|
||||
return tokens
|
||||
}
|
||||
|
||||
func normaliseLatency(latency time.Duration) int64 {
|
||||
if latency <= 0 {
|
||||
return 0
|
||||
}
|
||||
return latency.Milliseconds()
|
||||
}
|
||||
|
||||
func formatHour(hour int) string {
|
||||
if hour < 0 {
|
||||
hour = 0
|
||||
|
||||
96
internal/usage/logger_plugin_test.go
Normal file
96
internal/usage/logger_plugin_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestRequestStatisticsRecordIncludesLatency(t *testing.T) {
|
||||
stats := NewRequestStatistics()
|
||||
stats.Record(context.Background(), coreusage.Record{
|
||||
APIKey: "test-key",
|
||||
Model: "gpt-5.4",
|
||||
RequestedAt: time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC),
|
||||
Latency: 1500 * time.Millisecond,
|
||||
Detail: coreusage.Detail{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
})
|
||||
|
||||
snapshot := stats.Snapshot()
|
||||
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("details len = %d, want 1", len(details))
|
||||
}
|
||||
if details[0].LatencyMs != 1500 {
|
||||
t.Fatalf("latency_ms = %d, want 1500", details[0].LatencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestStatisticsMergeSnapshotDedupIgnoresLatency(t *testing.T) {
|
||||
stats := NewRequestStatistics()
|
||||
timestamp := time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC)
|
||||
first := StatisticsSnapshot{
|
||||
APIs: map[string]APISnapshot{
|
||||
"test-key": {
|
||||
Models: map[string]ModelSnapshot{
|
||||
"gpt-5.4": {
|
||||
Details: []RequestDetail{{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: 0,
|
||||
Source: "user@example.com",
|
||||
AuthIndex: "0",
|
||||
Tokens: TokenStats{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
second := StatisticsSnapshot{
|
||||
APIs: map[string]APISnapshot{
|
||||
"test-key": {
|
||||
Models: map[string]ModelSnapshot{
|
||||
"gpt-5.4": {
|
||||
Details: []RequestDetail{{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: 2500,
|
||||
Source: "user@example.com",
|
||||
AuthIndex: "0",
|
||||
Tokens: TokenStats{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := stats.MergeSnapshot(first)
|
||||
if result.Added != 1 || result.Skipped != 0 {
|
||||
t.Fatalf("first merge = %+v, want added=1 skipped=0", result)
|
||||
}
|
||||
|
||||
result = stats.MergeSnapshot(second)
|
||||
if result.Added != 0 || result.Skipped != 1 {
|
||||
t.Fatalf("second merge = %+v, want added=0 skipped=1", result)
|
||||
}
|
||||
|
||||
snapshot := stats.Snapshot()
|
||||
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("details len = %d, want 1", len(details))
|
||||
}
|
||||
}
|
||||
@@ -54,3 +54,77 @@ func TestSanitizeFunctionName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizedToolNameMap(t *testing.T) {
|
||||
t.Run("returns map for tools needing sanitization", func(t *testing.T) {
|
||||
raw := []byte(`{"tools":[
|
||||
{"name":"valid_tool","input_schema":{}},
|
||||
{"name":"mcp/server/read","input_schema":{}},
|
||||
{"name":"tool@v2","input_schema":{}}
|
||||
]}`)
|
||||
m := SanitizedToolNameMap(raw)
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil map")
|
||||
}
|
||||
if m["mcp_server_read"] != "mcp/server/read" {
|
||||
t.Errorf("expected mcp_server_read → mcp/server/read, got %q", m["mcp_server_read"])
|
||||
}
|
||||
if m["tool_v2"] != "tool@v2" {
|
||||
t.Errorf("expected tool_v2 → tool@v2, got %q", m["tool_v2"])
|
||||
}
|
||||
if _, exists := m["valid_tool"]; exists {
|
||||
t.Error("valid_tool should not be in the map (no sanitization needed)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when no tools need sanitization", func(t *testing.T) {
|
||||
raw := []byte(`{"tools":[{"name":"Read","input_schema":{}},{"name":"Write","input_schema":{}}]}`)
|
||||
m := SanitizedToolNameMap(raw)
|
||||
if m != nil {
|
||||
t.Errorf("expected nil, got %v", m)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil for empty/missing tools", func(t *testing.T) {
|
||||
if m := SanitizedToolNameMap([]byte(`{}`)); m != nil {
|
||||
t.Error("expected nil for no tools")
|
||||
}
|
||||
if m := SanitizedToolNameMap(nil); m != nil {
|
||||
t.Error("expected nil for nil input")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("collision keeps first mapping", func(t *testing.T) {
|
||||
raw := []byte(`{"tools":[
|
||||
{"name":"read/file","input_schema":{}},
|
||||
{"name":"read@file","input_schema":{}}
|
||||
]}`)
|
||||
m := SanitizedToolNameMap(raw)
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil map")
|
||||
}
|
||||
if m["read_file"] != "read/file" {
|
||||
t.Errorf("expected first mapping read/file, got %q", m["read_file"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRestoreSanitizedToolName(t *testing.T) {
|
||||
m := map[string]string{
|
||||
"mcp_server_read": "mcp/server/read",
|
||||
"tool_v2": "tool@v2",
|
||||
}
|
||||
|
||||
if got := RestoreSanitizedToolName(m, "mcp_server_read"); got != "mcp/server/read" {
|
||||
t.Errorf("expected mcp/server/read, got %q", got)
|
||||
}
|
||||
if got := RestoreSanitizedToolName(m, "unknown"); got != "unknown" {
|
||||
t.Errorf("expected passthrough for unknown, got %q", got)
|
||||
}
|
||||
if got := RestoreSanitizedToolName(nil, "name"); got != "name" {
|
||||
t.Errorf("expected passthrough for nil map, got %q", got)
|
||||
}
|
||||
if got := RestoreSanitizedToolName(m, ""); got != "" {
|
||||
t.Errorf("expected empty for empty name, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -243,6 +244,9 @@ func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
|
||||
out := make(map[string]string, len(toolResults))
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
name := strings.TrimSpace(tool.Get("name").String())
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(tool.Get("function.name").String())
|
||||
}
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
@@ -271,3 +275,54 @@ func MapToolName(toolNameMap map[string]string, name string) string {
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// SanitizedToolNameMap builds a sanitized-name → original-name map from Claude request tools.
|
||||
// It is used to restore exact tool names for clients (e.g. Claude Code) after the proxy
|
||||
// sanitizes tool names for Gemini/Vertex API compatibility via SanitizeFunctionName.
|
||||
// Only entries where sanitization actually changes the name are included.
|
||||
func SanitizedToolNameMap(rawJSON []byte) map[string]string {
|
||||
if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
|
||||
return nil
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(map[string]string)
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
name := strings.TrimSpace(tool.Get("name").String())
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
sanitized := SanitizeFunctionName(name)
|
||||
if sanitized == name {
|
||||
return true
|
||||
}
|
||||
if _, exists := out[sanitized]; !exists {
|
||||
out[sanitized] = name
|
||||
} else {
|
||||
log.Warnf("sanitized tool name collision: %q and %q both map to %q, keeping first", out[sanitized], name, sanitized)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RestoreSanitizedToolName looks up a sanitized function name in the provided map
|
||||
// and returns the original client-facing name. If no mapping exists, it returns
|
||||
// the sanitized name unchanged.
|
||||
func RestoreSanitizedToolName(toolNameMap map[string]string, sanitizedName string) string {
|
||||
if sanitizedName == "" || toolNameMap == nil {
|
||||
return sanitizedName
|
||||
}
|
||||
if original, ok := toolNameMap[sanitizedName]; ok {
|
||||
return original
|
||||
}
|
||||
return sanitizedName
|
||||
}
|
||||
|
||||
@@ -75,7 +75,12 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
cacheAuthContents := log.IsLevelEnabled(log.DebugLevel)
|
||||
if cacheAuthContents {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
} else {
|
||||
w.lastAuthContents = nil
|
||||
}
|
||||
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
@@ -89,10 +94,12 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
// Parse and cache auth content for future diff comparisons
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
// Parse and cache auth content for future diff comparisons (debug only).
|
||||
if cacheAuthContents {
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
}
|
||||
ctx := &synthesizer.SynthesisContext{
|
||||
Config: cfg,
|
||||
@@ -102,7 +109,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
}
|
||||
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
|
||||
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
|
||||
w.fileAuthsByPath[normalizedPath] = pathAuths
|
||||
w.fileAuthsByPath[normalizedPath] = authIDSet(pathAuths)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -171,25 +178,30 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
}
|
||||
|
||||
// Get old auth for diff comparison
|
||||
cacheAuthContents := log.IsLevelEnabled(log.DebugLevel)
|
||||
var oldAuth *coreauth.Auth
|
||||
if w.lastAuthContents != nil {
|
||||
if cacheAuthContents && w.lastAuthContents != nil {
|
||||
oldAuth = w.lastAuthContents[normalized]
|
||||
}
|
||||
|
||||
// Compute and log field changes
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
if cacheAuthContents {
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update caches
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
if cacheAuthContents {
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||
for id, a := range w.fileAuthsByPath[normalized] {
|
||||
@@ -206,7 +218,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
|
||||
newByID := authSliceToMap(generated)
|
||||
if len(newByID) > 0 {
|
||||
w.fileAuthsByPath[normalized] = newByID
|
||||
w.fileAuthsByPath[normalized] = authIDSet(newByID)
|
||||
} else {
|
||||
delete(w.fileAuthsByPath, normalized)
|
||||
}
|
||||
@@ -273,6 +285,14 @@ func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
|
||||
return byID
|
||||
}
|
||||
|
||||
func authIDSet(auths map[string]*coreauth.Auth) map[string]*coreauth.Auth {
|
||||
set := make(map[string]*coreauth.Auth, len(auths))
|
||||
for id := range auths {
|
||||
set[id] = nil
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
authFileCount := 0
|
||||
successfulAuthCount := 0
|
||||
|
||||
@@ -340,12 +340,13 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
}
|
||||
}
|
||||
newCtx, cancel := context.WithCancel(parentCtx)
|
||||
cancelCtx := newCtx
|
||||
if requestCtx != nil && requestCtx != parentCtx {
|
||||
go func() {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
cancel()
|
||||
case <-newCtx.Done():
|
||||
case <-cancelCtx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
95
sdk/auth/codebuddy.go
Normal file
95
sdk/auth/codebuddy.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CodeBuddyAuthenticator implements the browser OAuth polling flow for CodeBuddy.
|
||||
type CodeBuddyAuthenticator struct{}
|
||||
|
||||
// NewCodeBuddyAuthenticator constructs a new CodeBuddy authenticator.
|
||||
func NewCodeBuddyAuthenticator() Authenticator {
|
||||
return &CodeBuddyAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for codebuddy.
|
||||
func (CodeBuddyAuthenticator) Provider() string {
|
||||
return "codebuddy"
|
||||
}
|
||||
|
||||
// codeBuddyRefreshLead is the duration before token expiry when a refresh should be attempted.
|
||||
var codeBuddyRefreshLead = 24 * time.Hour
|
||||
|
||||
// RefreshLead returns how soon before expiry a refresh should be attempted.
|
||||
// CodeBuddy tokens have a long validity period, so we refresh 24 hours before expiry.
|
||||
func (CodeBuddyAuthenticator) RefreshLead() *time.Duration {
|
||||
return &codeBuddyRefreshLead
|
||||
}
|
||||
|
||||
// Login initiates the browser OAuth flow for CodeBuddy.
|
||||
func (a CodeBuddyAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("codebuddy: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
authSvc := codebuddy.NewCodeBuddyAuth(cfg)
|
||||
|
||||
authState, err := authSvc.FetchAuthState(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to fetch auth state: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\nPlease open the following URL in your browser to login:\n\n %s\n\n", authState.AuthURL)
|
||||
fmt.Println("Waiting for authorization...")
|
||||
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(authState.AuthURL); errOpen != nil {
|
||||
log.Debugf("codebuddy: failed to open browser: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := authSvc.PollForToken(ctx, authState.State)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: %s: %w", codebuddy.GetUserFriendlyMessage(err), err)
|
||||
}
|
||||
|
||||
fmt.Printf("\nSuccessfully logged in! (User ID: %s)\n", storage.UserID)
|
||||
|
||||
authID := fmt.Sprintf("codebuddy-%s.json", storage.UserID)
|
||||
|
||||
label := storage.UserID
|
||||
if label == "" {
|
||||
label = "codebuddy-user"
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: authID,
|
||||
Provider: a.Provider(),
|
||||
FileName: authID,
|
||||
Label: label,
|
||||
Storage: storage,
|
||||
Metadata: map[string]any{
|
||||
"access_token": storage.AccessToken,
|
||||
"refresh_token": storage.RefreshToken,
|
||||
"user_id": storage.UserID,
|
||||
"domain": storage.Domain,
|
||||
"expires_in": storage.ExpiresIn,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -18,6 +18,7 @@ func init() {
|
||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
|
||||
registerRefreshLead("codebuddy", func() Authenticator { return NewCodeBuddyAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -421,10 +421,6 @@ func preserveRequestedModelSuffix(requestedModel, resolved string) string {
|
||||
}
|
||||
|
||||
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
|
||||
return m.prepareExecutionModels(auth, routeModel)
|
||||
}
|
||||
|
||||
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
|
||||
requestedModel := rewriteModelForAuth(routeModel, auth)
|
||||
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
|
||||
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
|
||||
@@ -441,6 +437,46 @@ func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string
|
||||
return []string{resolved}
|
||||
}
|
||||
|
||||
func executionResultModel(routeModel, upstreamModel string, pooled bool) string {
|
||||
if pooled {
|
||||
if resolved := strings.TrimSpace(upstreamModel); resolved != "" {
|
||||
return resolved
|
||||
}
|
||||
}
|
||||
if requested := strings.TrimSpace(routeModel); requested != "" {
|
||||
return requested
|
||||
}
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
}
|
||||
|
||||
func filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
out := make([]string, 0, len(candidates))
|
||||
for _, upstreamModel := range candidates {
|
||||
stateModel := executionResultModel(routeModel, upstreamModel, pooled)
|
||||
blocked, _, _ := isAuthBlockedForModel(auth, stateModel, now)
|
||||
if blocked {
|
||||
continue
|
||||
}
|
||||
out = append(out, upstreamModel)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) preparedExecutionModels(auth *Auth, routeModel string) ([]string, bool) {
|
||||
candidates := m.executionModelCandidates(auth, routeModel)
|
||||
pooled := len(candidates) > 1
|
||||
return filterExecutionModels(auth, routeModel, candidates, pooled), pooled
|
||||
}
|
||||
|
||||
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
|
||||
models, _ := m.preparedExecutionModels(auth, routeModel)
|
||||
return models
|
||||
}
|
||||
|
||||
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
|
||||
if ch == nil {
|
||||
return
|
||||
@@ -451,6 +487,59 @@ func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
|
||||
}()
|
||||
}
|
||||
|
||||
type streamBootstrapError struct {
|
||||
cause error
|
||||
headers http.Header
|
||||
}
|
||||
|
||||
func cloneHTTPHeader(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func newStreamBootstrapError(err error, headers http.Header) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &streamBootstrapError{
|
||||
cause: err,
|
||||
headers: cloneHTTPHeader(headers),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *streamBootstrapError) Error() string {
|
||||
if e == nil || e.cause == nil {
|
||||
return ""
|
||||
}
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (e *streamBootstrapError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func (e *streamBootstrapError) Headers() http.Header {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return cloneHTTPHeader(e.headers)
|
||||
}
|
||||
|
||||
func streamErrorResult(headers http.Header, err error) *cliproxyexecutor.StreamResult {
|
||||
ch := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
ch <- cliproxyexecutor.StreamChunk{Err: err}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: cloneHTTPHeader(headers),
|
||||
Chunks: ch,
|
||||
}
|
||||
}
|
||||
|
||||
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
|
||||
if ch == nil {
|
||||
return nil, true, nil
|
||||
@@ -483,7 +572,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
|
||||
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, resultModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
@@ -496,7 +585,7 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
|
||||
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
return false
|
||||
@@ -526,19 +615,19 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
|
||||
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: true})
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
|
||||
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, execModels []string, pooled bool) (*cliproxyexecutor.StreamResult, error) {
|
||||
if executor == nil {
|
||||
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
execModels := m.prepareExecutionModels(auth, routeModel)
|
||||
var lastErr error
|
||||
for idx, execModel := range execModels {
|
||||
resultModel := executionResultModel(routeModel, execModel, pooled)
|
||||
execReq := req
|
||||
execReq.Model = execModel
|
||||
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
|
||||
@@ -550,7 +639,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(ctx, result)
|
||||
if isRequestInvalidError(errStream) {
|
||||
@@ -571,7 +660,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
@@ -582,31 +671,33 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
lastErr = bootstrapErr
|
||||
continue
|
||||
}
|
||||
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
|
||||
close(errCh)
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
|
||||
rerr := &Error{Message: bootstrapErr.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
return nil, newStreamBootstrapError(bootstrapErr, streamResult.Headers)
|
||||
}
|
||||
|
||||
if closed && len(buffered) == 0 {
|
||||
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: emptyErr}
|
||||
m.MarkResult(ctx, result)
|
||||
if idx < len(execModels)-1 {
|
||||
lastErr = emptyErr
|
||||
continue
|
||||
}
|
||||
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
|
||||
close(errCh)
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
|
||||
return nil, newStreamBootstrapError(emptyErr, streamResult.Headers)
|
||||
}
|
||||
|
||||
remaining := streamResult.Chunks
|
||||
@@ -615,7 +706,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
close(closedCh)
|
||||
remaining = closedCh
|
||||
}
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, resultModel, streamResult.Headers, buffered, remaining), nil
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
|
||||
@@ -979,9 +1070,10 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
attempted := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
|
||||
if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
@@ -1006,13 +1098,18 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
models := m.prepareExecutionModels(auth, routeModel)
|
||||
models, pooled := m.preparedExecutionModels(auth, routeModel)
|
||||
if len(models) == 0 {
|
||||
continue
|
||||
}
|
||||
attempted[auth.ID] = struct{}{}
|
||||
var authErr error
|
||||
for _, upstreamModel := range models {
|
||||
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
|
||||
execReq := req
|
||||
execReq.Model = upstreamModel
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
@@ -1051,9 +1148,10 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
attempted := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
|
||||
if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
@@ -1078,13 +1176,18 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
models := m.prepareExecutionModels(auth, routeModel)
|
||||
models, pooled := m.preparedExecutionModels(auth, routeModel)
|
||||
if len(models) == 0 {
|
||||
continue
|
||||
}
|
||||
attempted[auth.ID] = struct{}{}
|
||||
var authErr error
|
||||
for _, upstreamModel := range models {
|
||||
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
|
||||
execReq := req
|
||||
execReq.Model = upstreamModel
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
@@ -1096,14 +1199,14 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.hook.OnResult(execCtx, result)
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
}
|
||||
authErr = errExec
|
||||
continue
|
||||
}
|
||||
m.hook.OnResult(execCtx, result)
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
if authErr != nil {
|
||||
@@ -1123,10 +1226,15 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
attempted := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
|
||||
if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
var bootstrapErr *streamBootstrapError
|
||||
if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil {
|
||||
return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
@@ -1134,6 +1242,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
var bootstrapErr *streamBootstrapError
|
||||
if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil {
|
||||
return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
@@ -1149,7 +1261,12 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
|
||||
models, pooled := m.preparedExecutionModels(auth, routeModel)
|
||||
if len(models) == 0 {
|
||||
continue
|
||||
}
|
||||
attempted[auth.ID] = struct{}{}
|
||||
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
@@ -1627,53 +1744,60 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
}
|
||||
|
||||
statusCode := statusCodeFromResult(result.Error)
|
||||
switch statusCode {
|
||||
case 401:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "unauthorized"
|
||||
shouldSuspendModel = true
|
||||
case 402, 403:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "payment_required"
|
||||
shouldSuspendModel = true
|
||||
case 404:
|
||||
if isModelSupportResultError(result.Error) {
|
||||
next := now.Add(12 * time.Hour)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "not_found"
|
||||
suspendReason = "model_not_supported"
|
||||
shouldSuspendModel = true
|
||||
case 429:
|
||||
var next time.Time
|
||||
backoffLevel := state.Quota.BackoffLevel
|
||||
if result.RetryAfter != nil {
|
||||
next = now.Add(*result.RetryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
}
|
||||
backoffLevel = nextLevel
|
||||
}
|
||||
state.NextRetryAfter = next
|
||||
state.Quota = QuotaState{
|
||||
Exceeded: true,
|
||||
Reason: "quota",
|
||||
NextRecoverAt: next,
|
||||
BackoffLevel: backoffLevel,
|
||||
}
|
||||
suspendReason = "quota"
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
case 408, 500, 502, 503, 504:
|
||||
if quotaCooldownDisabledForAuth(auth) {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
} else {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "unauthorized"
|
||||
shouldSuspendModel = true
|
||||
case 402, 403:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "payment_required"
|
||||
shouldSuspendModel = true
|
||||
case 404:
|
||||
next := now.Add(12 * time.Hour)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "not_found"
|
||||
shouldSuspendModel = true
|
||||
case 429:
|
||||
var next time.Time
|
||||
backoffLevel := state.Quota.BackoffLevel
|
||||
if result.RetryAfter != nil {
|
||||
next = now.Add(*result.RetryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
}
|
||||
backoffLevel = nextLevel
|
||||
}
|
||||
state.NextRetryAfter = next
|
||||
state.Quota = QuotaState{
|
||||
Exceeded: true,
|
||||
Reason: "quota",
|
||||
NextRecoverAt: next,
|
||||
BackoffLevel: backoffLevel,
|
||||
}
|
||||
suspendReason = "quota"
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
case 408, 500, 502, 503, 504:
|
||||
if quotaCooldownDisabledForAuth(auth) {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
}
|
||||
default:
|
||||
state.NextRetryAfter = time.Time{}
|
||||
}
|
||||
default:
|
||||
state.NextRetryAfter = time.Time{}
|
||||
}
|
||||
|
||||
auth.Status = StatusError
|
||||
@@ -1883,14 +2007,65 @@ func statusCodeFromResult(err *Error) int {
|
||||
return err.StatusCode()
|
||||
}
|
||||
|
||||
func isModelSupportErrorMessage(message string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(message))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
patterns := [...]string{
|
||||
"model_not_supported",
|
||||
"requested model is not supported",
|
||||
"requested model is unsupported",
|
||||
"requested model is unavailable",
|
||||
"model is not supported",
|
||||
"model not supported",
|
||||
"unsupported model",
|
||||
"model unavailable",
|
||||
"not available for your plan",
|
||||
"not available for your account",
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(lower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isModelSupportError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromError(err)
|
||||
if status != http.StatusBadRequest && status != http.StatusUnprocessableEntity {
|
||||
return false
|
||||
}
|
||||
return isModelSupportErrorMessage(err.Error())
|
||||
}
|
||||
|
||||
func isModelSupportResultError(err *Error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromResult(err)
|
||||
if status != http.StatusBadRequest && status != http.StatusUnprocessableEntity {
|
||||
return false
|
||||
}
|
||||
return isModelSupportErrorMessage(err.Message)
|
||||
}
|
||||
|
||||
// isRequestInvalidError returns true if the error represents a client request
|
||||
// error that should not be retried. Specifically, it treats 400 responses with
|
||||
// "invalid_request_error" and all 422 responses as request-shape failures,
|
||||
// where switching auths or pooled upstream models will not help.
|
||||
// where switching auths or pooled upstream models will not help. Model-support
|
||||
// errors are excluded so routing can fall through to another auth or upstream.
|
||||
func isRequestInvalidError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if isModelSupportError(err) {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromError(err)
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
|
||||
@@ -108,6 +108,76 @@ func (e *credentialRetryLimitExecutor) Calls() int {
|
||||
return e.calls
|
||||
}
|
||||
|
||||
type authFallbackExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
executeCalls []string
|
||||
streamCalls []string
|
||||
executeErrors map[string]error
|
||||
streamFirstErrors map[string]error
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) Identifier() string {
|
||||
return e.id
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) Execute(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
e.mu.Lock()
|
||||
e.executeCalls = append(e.executeCalls, auth.ID)
|
||||
err := e.executeErrors[auth.ID]
|
||||
e.mu.Unlock()
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
return cliproxyexecutor.Response{Payload: []byte(auth.ID)}, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) ExecuteStream(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
e.streamCalls = append(e.streamCalls, auth.ID)
|
||||
err := e.streamFirstErrors[auth.ID]
|
||||
e.mu.Unlock()
|
||||
|
||||
ch := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
if err != nil {
|
||||
ch <- cliproxyexecutor.StreamChunk{Err: err}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil
|
||||
}
|
||||
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(auth.ID)}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "not implemented"}
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) ExecuteCalls() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.executeCalls))
|
||||
copy(out, e.executeCalls)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) StreamCalls() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.streamCalls))
|
||||
copy(out, e.streamCalls)
|
||||
return out
|
||||
}
|
||||
|
||||
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
|
||||
t.Helper()
|
||||
|
||||
@@ -191,6 +261,153 @@ func TestManager_MaxRetryCredentials_LimitsCrossCredentialRetries(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_ModelSupportBadRequest_FallsBackAndSuspendsAuth(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
executor := &authFallbackExecutor{
|
||||
id: "claude",
|
||||
executeErrors: map[string]error{
|
||||
"aa-bad-auth": &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
},
|
||||
},
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
model := "claude-opus-4-6"
|
||||
badAuth := &Auth{ID: "aa-bad-auth", Provider: "claude"}
|
||||
goodAuth := &Auth{ID: "bb-good-auth", Provider: "claude"}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(badAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
reg.RegisterClient(goodAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(badAuth.ID)
|
||||
reg.UnregisterClient(goodAuth.ID)
|
||||
})
|
||||
|
||||
if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
|
||||
t.Fatalf("register bad auth: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
|
||||
t.Fatalf("register good auth: %v", errRegister)
|
||||
}
|
||||
|
||||
request := cliproxyexecutor.Request{Model: model}
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, errExecute := m.Execute(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
if errExecute != nil {
|
||||
t.Fatalf("execute %d error = %v, want success", i, errExecute)
|
||||
}
|
||||
if string(resp.Payload) != goodAuth.ID {
|
||||
t.Fatalf("execute %d payload = %q, want %q", i, string(resp.Payload), goodAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.ExecuteCalls()
|
||||
want := []string{badAuth.ID, goodAuth.ID, goodAuth.ID}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d auth = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
|
||||
updatedBad, ok := m.GetByID(badAuth.ID)
|
||||
if !ok || updatedBad == nil {
|
||||
t.Fatalf("expected bad auth to remain registered")
|
||||
}
|
||||
state := updatedBad.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state for %q", model)
|
||||
}
|
||||
if !state.Unavailable {
|
||||
t.Fatalf("expected bad auth model state to be unavailable")
|
||||
}
|
||||
if state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected bad auth model state cooldown to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_ModelSupportBadRequestFallsBackAndSuspendsAuth(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
executor := &authFallbackExecutor{
|
||||
id: "claude",
|
||||
streamFirstErrors: map[string]error{
|
||||
"aa-bad-auth": &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
},
|
||||
},
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
model := "claude-opus-4-6"
|
||||
badAuth := &Auth{ID: "aa-bad-auth", Provider: "claude"}
|
||||
goodAuth := &Auth{ID: "bb-good-auth", Provider: "claude"}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(badAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
reg.RegisterClient(goodAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(badAuth.ID)
|
||||
reg.UnregisterClient(goodAuth.ID)
|
||||
})
|
||||
|
||||
if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
|
||||
t.Fatalf("register bad auth: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
|
||||
t.Fatalf("register good auth: %v", errRegister)
|
||||
}
|
||||
|
||||
request := cliproxyexecutor.Request{Model: model}
|
||||
for i := 0; i < 2; i++ {
|
||||
streamResult, errExecute := m.ExecuteStream(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
if errExecute != nil {
|
||||
t.Fatalf("execute stream %d error = %v, want success", i, errExecute)
|
||||
}
|
||||
var payload []byte
|
||||
for chunk := range streamResult.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("execute stream %d chunk error = %v, want success", i, chunk.Err)
|
||||
}
|
||||
payload = append(payload, chunk.Payload...)
|
||||
}
|
||||
if string(payload) != goodAuth.ID {
|
||||
t.Fatalf("execute stream %d payload = %q, want %q", i, string(payload), goodAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.StreamCalls()
|
||||
want := []string{badAuth.ID, goodAuth.ID, goodAuth.ID}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("stream calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d auth = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
|
||||
updatedBad, ok := m.GetByID(badAuth.ID)
|
||||
if !ok || updatedBad == nil {
|
||||
t.Fatalf("expected bad auth to remain registered")
|
||||
}
|
||||
state := updatedBad.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state for %q", model)
|
||||
}
|
||||
if !state.Unavailable {
|
||||
t.Fatalf("expected bad auth model state to be unavailable")
|
||||
}
|
||||
if state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected bad auth model state cooldown to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
|
||||
prev := quotaCooldownDisabled.Load()
|
||||
quotaCooldownDisabled.Store(false)
|
||||
|
||||
@@ -3,6 +3,7 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -116,6 +117,47 @@ func (e *openAICompatPoolExecutor) StreamModels() []string {
|
||||
return out
|
||||
}
|
||||
|
||||
type authScopedOpenAICompatPoolExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
executeCalls []string
|
||||
}
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) Identifier() string { return e.id }
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) Execute(_ context.Context, auth *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
call := auth.ID + "|" + req.Model
|
||||
e.mu.Lock()
|
||||
e.executeCalls = append(e.executeCalls, call)
|
||||
e.mu.Unlock()
|
||||
return cliproxyexecutor.Response{Payload: []byte(call)}, nil
|
||||
}
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "ExecuteStream not implemented"}
|
||||
}
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
|
||||
}
|
||||
|
||||
func (e *authScopedOpenAICompatPoolExecutor) ExecuteCalls() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.executeCalls))
|
||||
copy(out, e.executeCalls)
|
||||
return out
|
||||
}
|
||||
|
||||
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
|
||||
t.Helper()
|
||||
cfg := &internalconfig.Config{
|
||||
@@ -153,6 +195,21 @@ func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []interna
|
||||
return m
|
||||
}
|
||||
|
||||
func readOpenAICompatStreamPayload(t *testing.T, streamResult *cliproxyexecutor.StreamResult) string {
|
||||
t.Helper()
|
||||
if streamResult == nil {
|
||||
t.Fatal("expected stream result")
|
||||
}
|
||||
var payload []byte
|
||||
for chunk := range streamResult.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
payload = append(payload, chunk.Payload...)
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||
@@ -243,6 +300,87 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
|
||||
t.Fatalf("execute calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
modelSupportErr := &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute error = %v, want fallback success", err)
|
||||
}
|
||||
if string(resp.Payload) != "glm-5" {
|
||||
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
|
||||
updated, ok := m.GetByID("pool-auth-" + t.Name())
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth to remain registered")
|
||||
}
|
||||
state := updated.ModelStates["qwen3.5-plus"]
|
||||
if state == nil {
|
||||
t.Fatalf("expected suspended upstream model state")
|
||||
}
|
||||
if !state.Unavailable || state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected upstream model suspension, got %+v", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessableEntity(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
modelSupportErr := &Error{
|
||||
HTTPStatus: http.StatusUnprocessableEntity,
|
||||
Message: "The requested model is not supported.",
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute error = %v, want fallback success", err)
|
||||
}
|
||||
if string(resp.Payload) != "glm-5" {
|
||||
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{
|
||||
@@ -364,6 +502,84 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test
|
||||
t.Fatalf("stream calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
modelSupportErr := &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute %d: %v", i, err)
|
||||
}
|
||||
if string(resp.Payload) != "glm-5" {
|
||||
t.Fatalf("execute %d payload = %q, want %q", i, string(resp.Payload), "glm-5")
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
modelSupportErr := &Error{
|
||||
HTTPStatus: http.StatusUnprocessableEntity,
|
||||
Message: "The requested model is not supported.",
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute stream %d: %v", i, err)
|
||||
}
|
||||
if payload := readOpenAICompatStreamPayload(t, streamResult); payload != "glm-5" {
|
||||
t.Fatalf("execute stream %d payload = %q, want %q", i, payload, "glm-5")
|
||||
}
|
||||
if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
|
||||
t.Fatalf("execute stream %d header X-Model = %q, want %q", i, gotHeader, "glm-5")
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.StreamModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("stream calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||
@@ -391,6 +607,127 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
modelSupportErr := &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is unsupported.",
|
||||
}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
countErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute count %d: %v", i, err)
|
||||
}
|
||||
if string(resp.Payload) != "glm-5" {
|
||||
t.Fatalf("execute count %d payload = %q, want %q", i, string(resp.Payload), "glm-5")
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.CountModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("count calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudget(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
cfg := &internalconfig.Config{
|
||||
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
|
||||
Name: "pool",
|
||||
Models: []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
},
|
||||
}},
|
||||
}
|
||||
m := NewManager(nil, nil, nil)
|
||||
m.SetConfig(cfg)
|
||||
m.SetRetryConfig(0, 0, 1)
|
||||
|
||||
executor := &authScopedOpenAICompatPoolExecutor{id: "pool"}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
badAuth := &Auth{
|
||||
ID: "aa-blocked-auth",
|
||||
Provider: "pool",
|
||||
Status: StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"api_key": "bad-key",
|
||||
"compat_name": "pool",
|
||||
"provider_key": "pool",
|
||||
},
|
||||
}
|
||||
goodAuth := &Auth{
|
||||
ID: "bb-good-auth",
|
||||
Provider: "pool",
|
||||
Status: StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"api_key": "good-key",
|
||||
"compat_name": "pool",
|
||||
"provider_key": "pool",
|
||||
},
|
||||
}
|
||||
if _, err := m.Register(context.Background(), badAuth); err != nil {
|
||||
t.Fatalf("register bad auth: %v", err)
|
||||
}
|
||||
if _, err := m.Register(context.Background(), goodAuth); err != nil {
|
||||
t.Fatalf("register good auth: %v", err)
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(badAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
|
||||
reg.RegisterClient(goodAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(badAuth.ID)
|
||||
reg.UnregisterClient(goodAuth.ID)
|
||||
})
|
||||
|
||||
modelSupportErr := &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
}
|
||||
for _, upstreamModel := range []string{"qwen3.5-plus", "glm-5"} {
|
||||
m.MarkResult(context.Background(), Result{
|
||||
AuthID: badAuth.ID,
|
||||
Provider: "pool",
|
||||
Model: upstreamModel,
|
||||
Success: false,
|
||||
Error: modelSupportErr,
|
||||
})
|
||||
}
|
||||
|
||||
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute error = %v, want success via fallback auth", err)
|
||||
}
|
||||
if !strings.HasPrefix(string(resp.Payload), goodAuth.ID+"|") {
|
||||
t.Fatalf("payload = %q, want auth %q", string(resp.Payload), goodAuth.ID)
|
||||
}
|
||||
|
||||
got := executor.ExecuteCalls()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("execute calls = %v, want only one real execution on fallback auth", got)
|
||||
}
|
||||
if !strings.HasPrefix(got[0], goodAuth.ID+"|") {
|
||||
t.Fatalf("execute call = %q, want fallback auth %q", got[0], goodAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||
|
||||
@@ -443,6 +443,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
case "codebuddy":
|
||||
s.coreManager.RegisterExecutor(executor.NewCodeBuddyExecutor(s.cfg))
|
||||
case "gitlab":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
|
||||
default:
|
||||
@@ -954,6 +956,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "gitlab":
|
||||
models = executor.GitLabModelsFromAuth(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "codebuddy":
|
||||
models = registry.GetCodeBuddyModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
@@ -1006,6 +1011,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if modelID == "" {
|
||||
modelID = m.Name
|
||||
}
|
||||
thinking := m.Thinking
|
||||
if thinking == nil {
|
||||
thinking = ®istry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
|
||||
}
|
||||
ms = append(ms, &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
@@ -1013,7 +1022,8 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
OwnedBy: compat.Name,
|
||||
Type: "openai-compatibility",
|
||||
DisplayName: modelID,
|
||||
UserDefined: true,
|
||||
UserDefined: false,
|
||||
Thinking: thinking,
|
||||
})
|
||||
}
|
||||
// Register and return
|
||||
|
||||
@@ -17,6 +17,7 @@ type Record struct {
|
||||
AuthIndex string
|
||||
Source string
|
||||
RequestedAt time.Time
|
||||
Latency time.Duration
|
||||
Failed bool
|
||||
Detail Detail
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@ package translator
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Registry manages translation functions across schemas.
|
||||
@@ -39,7 +43,9 @@ func (r *Registry) Register(from, to Format, request RequestTransform, response
|
||||
}
|
||||
|
||||
// TranslateRequest converts a payload between schemas, returning the original payload
|
||||
// if no translator is registered.
|
||||
// if no translator is registered. When falling back to the original payload, the
|
||||
// "model" field is still updated to match the resolved model name so that
|
||||
// client-side prefixes (e.g. "copilot/gpt-5-mini") are not leaked upstream.
|
||||
func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
@@ -49,6 +55,13 @@ func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byt
|
||||
return fn(model, rawJSON, stream)
|
||||
}
|
||||
}
|
||||
if model != "" && gjson.GetBytes(rawJSON, "model").String() != model {
|
||||
if updated, err := sjson.SetBytes(rawJSON, "model", model); err != nil {
|
||||
log.Warnf("translator: failed to normalize model in request fallback: %v", err)
|
||||
} else {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
|
||||
92
sdk/translator/registry_test.go
Normal file
92
sdk/translator/registry_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestTranslateRequest_FallbackNormalizesModel(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
payload string
|
||||
wantModel string
|
||||
wantUnchanged bool
|
||||
}{
|
||||
{
|
||||
name: "prefixed model is rewritten",
|
||||
model: "gpt-5-mini",
|
||||
payload: `{"model":"copilot/gpt-5-mini","input":"ping"}`,
|
||||
wantModel: "gpt-5-mini",
|
||||
},
|
||||
{
|
||||
name: "matching model is left unchanged",
|
||||
model: "gpt-5-mini",
|
||||
payload: `{"model":"gpt-5-mini","input":"ping"}`,
|
||||
wantModel: "gpt-5-mini",
|
||||
wantUnchanged: true,
|
||||
},
|
||||
{
|
||||
name: "empty model leaves payload unchanged",
|
||||
model: "",
|
||||
payload: `{"model":"copilot/gpt-5-mini","input":"ping"}`,
|
||||
wantModel: "copilot/gpt-5-mini",
|
||||
wantUnchanged: true,
|
||||
},
|
||||
{
|
||||
name: "deeply prefixed model is rewritten",
|
||||
model: "gpt-5.3-codex",
|
||||
payload: `{"model":"team/gpt-5.3-codex","stream":true}`,
|
||||
wantModel: "gpt-5.3-codex",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
input := []byte(tt.payload)
|
||||
got := r.TranslateRequest(Format("a"), Format("b"), tt.model, input, false)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Errorf("model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
|
||||
if tt.wantUnchanged && string(got) != tt.payload {
|
||||
t.Errorf("payload was modified when it should not have been:\ngot: %s\nwant: %s", got, tt.payload)
|
||||
}
|
||||
|
||||
// Verify other fields are preserved.
|
||||
for _, key := range []string{"input", "stream"} {
|
||||
orig := gjson.Get(tt.payload, key)
|
||||
if !orig.Exists() {
|
||||
continue
|
||||
}
|
||||
after := gjson.GetBytes(got, key)
|
||||
if orig.Raw != after.Raw {
|
||||
t.Errorf("field %q changed: got %s, want %s", key, after.Raw, orig.Raw)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateRequest_RegisteredTransformTakesPrecedence(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
from := Format("openai-response")
|
||||
to := Format("openai-response")
|
||||
|
||||
r.Register(from, to, func(model string, rawJSON []byte, stream bool) []byte {
|
||||
return []byte(`{"model":"from-transform"}`)
|
||||
}, ResponseTransform{})
|
||||
|
||||
input := []byte(`{"model":"copilot/gpt-5-mini","input":"ping"}`)
|
||||
got := r.TranslateRequest(from, to, "gpt-5-mini", input, false)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != "from-transform" {
|
||||
t.Errorf("expected registered transform to take precedence, got model = %q", gotModel)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user