mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-03 11:12:46 +00:00
Compare commits
44 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 1dc4ecb1b8 | | |
| | 1315f710f5 | | |
| | 96f55570f7 | | |
| | 0906aeca87 | | |
| | 97c0487add | | |
| | a576088d5f | | |
| | 66ff916838 | | |
| | 7b0453074e | | |
| | a000eb523d | | |
| | 18a4fedc7f | | |
| | 5d6cdccda0 | | |
| | 1b7f4ac3e1 | | |
| | afc1a5b814 | | |
| | 7ed38db54f | | |
| | 28c10f4e69 | | |
| | 6e12441a3b | | |
| | 65c439c18d | | |
| | 0ed2d16596 | | |
| | db335ac616 | | |
| | e8bb350467 | | |
| | 5331d51f27 | | |
| | 755ca75879 | | |
| | 2398ebad55 | | |
| | c1bf298216 | | |
| | e005208d76 | | |
| | d1df70d02f | | |
| | 52c1fa025e | | |
| | 680105f84d | | |
| | f7069e9548 | | |
| | 793840cdb4 | | |
| | 8f421de532 | | |
| | be2dd60ee7 | | |
| | ea3e0b713e | | |
| | 8179d5a8a4 | | |
| | 6fa7abe434 | | |
| | 5135c22cd6 | | |
| | 1e27990561 | | |
| | e1e9fc43c1 | | |
| | b2921518ac | | |
| | dd64adbeeb | | |
| | 616d41c06a | | |
| | e0e337aeb9 | | |
| | d52839fced | | |
| | c3762328a5 | | |
@@ -1,6 +1,6 @@
 # CLIProxyAPI Plus

-[English](README.md) | 中文
+[English](README.md) | 中文 | [日本語](README_JA.md)

 This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), which adds support for third-party providers on top of the original.

183
README_JA.md
Normal file
@@ -0,0 +1,183 @@
# CLI Proxy API

[English](README.md) | [中文](README_CN.md) | 日本語

A proxy server that provides OpenAI/Gemini/Claude/Codex-compatible API interfaces for CLI.

It also supports OpenAI Codex (GPT models) and Claude Code via OAuth.

You can use local or multi-account CLI access with OpenAI (including Responses)/Gemini/Claude-compatible clients and SDKs.

## Sponsor

[](https://z.ai/subscribe?ic=8JVLJQFSKB)

This project is sponsored by Z.ai, which supports it with the GLM CODING PLAN.

GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10 per month. It gives access to the flagship GLM-4.7 model (and GLM-5, available to Pro users only) across more than 10 popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers a top-tier, fast, and stable coding experience.

Get 10% off GLM CODING PLAN: https://z.ai/subscribe?ic=8JVLJQFSKB

---

<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides a special discount for users of this software: register via <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the promo code "cliproxyapi" when topping up to get 10% off.</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. The official Claude Code / Codex / Gemini channels are available at 38% / 2% / 9% of the original price, with further discounts on top-ups! Special offer for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to get 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
</tbody>
</table>

## Overview

- OpenAI/Gemini/Claude-compatible API endpoints for CLI models
- OpenAI Codex support (GPT models) via OAuth login
- Claude Code support via OAuth login
- Qwen Code support via OAuth login
- iFlow support via OAuth login
- Amp CLI and IDE extension support with provider routing
- Streaming and non-streaming responses
- Function calling/tools support
- Multimodal input support (text and images)
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen, and iFlow)
- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen, and iFlow)
- Generative Language API key support
- AI Studio Build multi-account load balancing
- Gemini CLI multi-account load balancing
- Claude Code multi-account load balancing
- Qwen Code multi-account load balancing
- iFlow multi-account load balancing
- OpenAI Codex multi-account load balancing
- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
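
The round-robin load balancing mentioned in the list above can be pictured with a minimal sketch. This is not taken from the CLIProxyAPI source: the `roundRobin` type, its `pick` method, and the account names are hypothetical and only illustrate the rotation idea.

```go
package main

import (
	"fmt"
	"sync"
)

// roundRobin hands out account names in rotation. It is a hypothetical helper
// for illustration only and is not part of the CLIProxyAPI code base.
type roundRobin struct {
	mu       sync.Mutex
	accounts []string
	next     int
}

// pick returns the next account and advances the cursor. ok is false when no
// accounts are configured.
func (r *roundRobin) pick() (account string, ok bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if len(r.accounts) == 0 {
		return "", false
	}
	account = r.accounts[r.next%len(r.accounts)]
	r.next = (r.next + 1) % len(r.accounts)
	return account, true
}

func main() {
	rr := &roundRobin{accounts: []string{"gemini-a", "gemini-b", "gemini-c"}}
	// Four picks over three accounts wrap around to the first one again.
	for i := 0; i < 4; i++ {
		account, _ := rr.pick()
		fmt.Println(account)
	}
}
```

The mutex keeps the cursor consistent when several requests pick an account concurrently.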

## Getting Started

CLIProxyAPI guide: [https://help.router-for.me/ja/](https://help.router-for.me/ja/)

## Management API

See [MANAGEMENT_API.md](https://help.router-for.me/ja/management/api)

## Amp CLI Support

CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and the Amp IDE extensions, so you can use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:

- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
- Management proxy for OAuth authentication and account features
- Smart model fallback with automatic routing
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
- Security-first design with localhost-only management endpoints

**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/ja/agent-client/amp-cli.html)**
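
As a rough illustration of the model-mapping idea, the sketch below reroutes a request for an unavailable model to a configured alternative. The `resolveModel` function and the two maps are assumptions made for this example; the actual CLIProxyAPI configuration keys and lookup logic are described in the integration guide and may differ.

```go
package main

import "fmt"

// resolveModel returns the mapped alternative when the requested model is not
// available and a usable mapping exists; otherwise it returns the request
// unchanged. Hypothetical helper, not the CLIProxyAPI implementation.
func resolveModel(requested string, available map[string]bool, mapping map[string]string) string {
	if available[requested] {
		return requested
	}
	if alt, ok := mapping[requested]; ok && available[alt] {
		return alt
	}
	return requested
}

func main() {
	available := map[string]bool{"claude-sonnet-4": true}
	mapping := map[string]string{"claude-opus-4.5": "claude-sonnet-4"}

	fmt.Println(resolveModel("claude-opus-4.5", available, mapping)) // rerouted
	fmt.Println(resolveModel("claude-sonnet-4", available, mapping)) // unchanged
}
```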

## SDK Docs

- Usage: [docs/sdk-usage.md](docs/sdk-usage.md)
- Advanced (executors and translators): [docs/sdk-advanced.md](docs/sdk-advanced.md)
- Access: [docs/sdk-access.md](docs/sdk-access.md)
- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md)
- Custom provider example: `examples/custom-provider`

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request

## Related Projects

The following projects are based on CLIProxyAPI:

### [vibeproxy](https://github.com/automazeio/vibeproxy)

Native macOS menu bar app that lets you use your Claude Code and ChatGPT subscriptions with AI coding tools - no API keys needed

### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)

Browser-based tool that translates SRT subtitles using your Gemini subscription via CLIProxyAPI, with automatic validation and error correction - no API keys needed

### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)

CLI wrapper that uses CLIProxyAPI OAuth to switch instantly between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) - no API keys needed

### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)

Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed

### [Quotio](https://github.com/nguyenphutrong/quotio)

Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions, with real-time quota tracking and smart automatic failover. Built for coding tools such as Claude Code, OpenCode, and Droid - no API keys needed

### [CodMate](https://github.com/loocor/CodMate)

Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI), with unified provider management, Git review, project organization, global search, and terminal integration. It integrates with CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, and supports rerouting built-in and third-party providers through a single proxy endpoint - no API keys needed for OAuth providers

### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)

CLIProxyAPI fork for Windows with a TUI, system tray, and multi-provider OAuth, for AI coding tools - no API keys needed

### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)

VSCode extension for quickly switching Claude Code models. It integrates CLIProxyAPI as its backend and manages its lifecycle automatically in the background

### [ZeroLimit](https://github.com/0xtbug/zero-limit)

Windows desktop app built with Tauri + React that monitors AI coding assistant quotas via CLIProxyAPI. It tracks usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with a real-time dashboard, system tray integration, and one-click proxy control - no API keys needed

### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)

Lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics, and pricing display. Supports one-click installation and a systemd service

### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)

Windows tray application implemented as PowerShell scripts, with no third-party library dependencies. Supports automatic shortcut creation, silent running, password management, channel switching (Main / Plus), and automatic download and update

### [霖君](https://github.com/wangdabaoqq/LinJun)

霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux. It provides unified management of coding tools such as Claude Code, Gemini CLI, OpenAI Codex, and Qwen Code, with multi-account quota tracking and one-click configuration through a local proxy

### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)

Modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. It offers real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via a companion plugin - no manual YAML editing needed

### [All API Hub](https://github.com/qixing-jk/all-api-hub)

Browser extension for one-stop management of New API-compatible relay site accounts. It offers balance and usage dashboards, automatic check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync

### [Shadow AI](https://github.com/HEUDavid/shadow-ai)

Shadow AI is an AI assistant tool designed specifically for restricted environments. It provides a stealth mode with no windows or traces, and enables cross-device AI question answering and control over a LAN (local area network). In essence, it is an automated collaboration layer of "screen/audio capture + AI inference + low-friction delivery" that helps users work with an AI assistant across applications on controlled devices or in restricted environments.

> [!NOTE]
> If you have developed a project based on CLIProxyAPI, please open a PR to add it to this list.

## More Choices

The following projects are ports of CLIProxyAPI or were inspired by it:

### [9Router](https://github.com/decolua/9router)

A Next.js implementation inspired by CLIProxyAPI. Easy to install and use, and built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), a combo system with automatic fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed

### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)

Never stop coding. Smart routing to free and low-cost AI models with automatic fallback.

OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.

> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

@@ -175,12 +175,19 @@ nonstream-keepalive-interval: 0
 # cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request

 # Default headers for Claude API requests. Update when Claude Code releases new versions.
-# These are used as fallbacks when the client does not send its own headers.
+# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
+# when the client omits them, while OS/arch remain runtime-derived. When
+# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
+# while user-agent/package-version/runtime-version seed a software fingerprint that can
+# still upgrade to newer official Claude client versions.
 # claude-header-defaults:
 #   user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
 #   package-version: "0.74.0"
 #   runtime-version: "v24.3.0"
+#   os: "MacOS"
+#   arch: "arm64"
 #   timeout: "600"
+#   stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning

 # Default headers for Codex OAuth model requests.
 # These are used only for file-backed/OAuth Codex requests when the client

@@ -748,10 +748,25 @@ func (h *Handler) authIDForPath(path string) string {
 	if path == "" {
 		return ""
 	}
+	path = filepath.Clean(path)
+	if !filepath.IsAbs(path) {
+		if abs, errAbs := filepath.Abs(path); errAbs == nil {
+			path = abs
+		}
+	}
 	id := path
 	if h != nil && h.cfg != nil {
 		authDir := strings.TrimSpace(h.cfg.AuthDir)
+		if resolvedAuthDir, errResolve := util.ResolveAuthDir(authDir); errResolve == nil && resolvedAuthDir != "" {
+			authDir = resolvedAuthDir
+		}
 		if authDir != "" {
+			authDir = filepath.Clean(authDir)
+			if !filepath.IsAbs(authDir) {
+				if abs, errAbs := filepath.Abs(authDir); errAbs == nil {
+					authDir = abs
+				}
+			}
 			if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
 				id = rel
 			}

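The normalization in this hunk (clean both paths, make them absolute, then take the path relative to the auth directory) can be tried in isolation with the sketch below. It is a simplified stand-in, not the handler itself: `relativeAuthID` and `absClean` are hypothetical names, and the `util.ResolveAuthDir` step is left out because its behavior is not shown in the diff.

```go
package main

import (
	"fmt"
	"path/filepath"
)

// absClean cleans p and, when it is relative, resolves it against the current
// working directory. On failure it falls back to the cleaned path.
func absClean(p string) string {
	p = filepath.Clean(p)
	if !filepath.IsAbs(p) {
		if abs, err := filepath.Abs(p); err == nil {
			return abs
		}
	}
	return p
}

// relativeAuthID returns path relative to authDir when that can be computed,
// and the normalized absolute path otherwise. Hypothetical helper that only
// mirrors the Clean/Abs/Rel sequence from the hunk.
func relativeAuthID(authDir, path string) string {
	if path == "" {
		return ""
	}
	path = absClean(path)
	if authDir == "" {
		return path
	}
	if rel, err := filepath.Rel(absClean(authDir), path); err == nil && rel != "" {
		return rel
	}
	return path
}

func main() {
	// A relative and an absolute spelling of the same layout yield the same ID.
	fmt.Println(relativeAuthID("auths", "auths/codex.json"))
	fmt.Println(relativeAuthID("/data/auths/", "/data/auths/./team/claude.json"))
}
```

Normalizing both sides first matters because `filepath.Rel` returns an error when one argument is absolute and the other is relative, in which case the ID would fall back to the full path.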
55
internal/config/claude_header_defaults_test.go
Normal file
@@ -0,0 +1,55 @@
package config

import (
	"os"
	"path/filepath"
	"testing"
)

func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
	dir := t.TempDir()
	configPath := filepath.Join(dir, "config.yaml")
	configYAML := []byte(`
claude-header-defaults:
  user-agent: " claude-cli/2.1.70 (external, cli) "
  package-version: " 0.80.0 "
  runtime-version: " v24.5.0 "
  os: " MacOS "
  arch: " arm64 "
  timeout: " 900 "
  stabilize-device-profile: false
`)
	if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
		t.Fatalf("failed to write config: %v", err)
	}

	cfg, err := LoadConfigOptional(configPath, false)
	if err != nil {
		t.Fatalf("LoadConfigOptional() error = %v", err)
	}

	if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
		t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
	}
	if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
		t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
	}
	if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
		t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
	}
	if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
		t.Fatalf("OS = %q, want %q", got, "MacOS")
	}
	if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
		t.Fatalf("Arch = %q, want %q", got, "arm64")
	}
	if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
		t.Fatalf("Timeout = %q, want %q", got, "900")
	}
	if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
		t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
	}
	if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
		t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
	}
}
@@ -145,13 +145,19 @@ type Config struct {
 	legacyMigrationPending bool `yaml:"-" json:"-"`
 }

-// ClaudeHeaderDefaults configures default header values injected into Claude API requests
-// when the client does not send them. Update these when Claude Code releases a new version.
+// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
+// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
+// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
+// profiles are enabled, OS/Arch become the pinned platform baseline, while
+// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
 type ClaudeHeaderDefaults struct {
-	UserAgent      string `yaml:"user-agent" json:"user-agent"`
-	PackageVersion string `yaml:"package-version" json:"package-version"`
-	RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
-	Timeout        string `yaml:"timeout" json:"timeout"`
+	UserAgent              string `yaml:"user-agent" json:"user-agent"`
+	PackageVersion         string `yaml:"package-version" json:"package-version"`
+	RuntimeVersion         string `yaml:"runtime-version" json:"runtime-version"`
+	OS                     string `yaml:"os" json:"os"`
+	Arch                   string `yaml:"arch" json:"arch"`
+	Timeout                string `yaml:"timeout" json:"timeout"`
+	StabilizeDeviceProfile *bool  `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
 }

 // CodexHeaderDefaults configures fallback header values injected into Codex
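`StabilizeDeviceProfile` is declared as `*bool`, which lets the loader tell "not configured" apart from an explicit `false`. The sketch below shows that tri-state behavior using the struct's JSON tag and the standard `encoding/json` package (the YAML path would need a third-party decoder, so it is not used here). The `headerDefaults` type and both helper functions are illustrative stand-ins, not code from the repository.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// headerDefaults is a trimmed-down stand-in for ClaudeHeaderDefaults that
// keeps only the pointer-typed flag.
type headerDefaults struct {
	StabilizeDeviceProfile *bool `json:"stabilize-device-profile,omitempty"`
}

// stabilizationEnabled treats an unset flag the same as an explicit false.
func stabilizationEnabled(hd headerDefaults) bool {
	return hd.StabilizeDeviceProfile != nil && *hd.StabilizeDeviceProfile
}

// describe reports whether the flag was absent, explicitly false, or true.
func describe(raw string) (string, error) {
	var hd headerDefaults
	if err := json.Unmarshal([]byte(raw), &hd); err != nil {
		return "", err
	}
	switch {
	case hd.StabilizeDeviceProfile == nil:
		return "unset", nil
	case *hd.StabilizeDeviceProfile:
		return "true", nil
	default:
		return "false", nil
	}
}

func main() {
	for _, raw := range []string{
		`{}`,
		`{"stabilize-device-profile": false}`,
		`{"stabilize-device-profile": true}`,
	} {
		state, err := describe(raw)
		if err != nil {
			fmt.Println("decode error:", err)
			continue
		}
		fmt.Printf("%-40s -> %s\n", raw, state)
	}
}
```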
@@ -694,6 +700,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
 	// Sanitize Codex header defaults.
 	cfg.SanitizeCodexHeaderDefaults()

+	// Sanitize Claude header defaults.
+	cfg.SanitizeClaudeHeaderDefaults()
+
 	// Sanitize Claude key headers
 	cfg.SanitizeClaudeKeys()

@@ -796,6 +805,20 @@ func (cfg *Config) SanitizeCodexHeaderDefaults() {
 	cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
 }

+// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
+// configured Claude fingerprint baseline values.
+func (cfg *Config) SanitizeClaudeHeaderDefaults() {
+	if cfg == nil {
+		return
+	}
+	cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
+	cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
+	cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
+	cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
+	cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
+	cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
+}
+
 // SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
 // It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
 // allows multiple aliases per upstream name, and ensures aliases are unique within each channel.

383
internal/runtime/executor/claude_device_profile.go
Normal file
@@ -0,0 +1,383 @@
package executor

import (
	"crypto/sha256"
	"encoding/hex"
	"net/http"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)

const (
	defaultClaudeFingerprintUserAgent      = "claude-cli/2.1.63 (external, cli)"
	defaultClaudeFingerprintPackageVersion = "0.74.0"
	defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
	defaultClaudeFingerprintOS             = "MacOS"
	defaultClaudeFingerprintArch           = "arm64"
	claudeDeviceProfileTTL                 = 7 * 24 * time.Hour
	claudeDeviceProfileCleanupPeriod       = time.Hour
)

var (
	claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)

	claudeDeviceProfileCache            = make(map[string]claudeDeviceProfileCacheEntry)
	claudeDeviceProfileCacheMu          sync.RWMutex
	claudeDeviceProfileCacheCleanupOnce sync.Once

	claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
)

type claudeCLIVersion struct {
	major int
	minor int
	patch int
}

func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
	switch {
	case v.major != other.major:
		if v.major > other.major {
			return 1
		}
		return -1
	case v.minor != other.minor:
		if v.minor > other.minor {
			return 1
		}
		return -1
	case v.patch != other.patch:
		if v.patch > other.patch {
			return 1
		}
		return -1
	default:
		return 0
	}
}

type claudeDeviceProfile struct {
	UserAgent      string
	PackageVersion string
	RuntimeVersion string
	OS             string
	Arch           string
	Version        claudeCLIVersion
	HasVersion     bool
}

type claudeDeviceProfileCacheEntry struct {
	profile claudeDeviceProfile
	expire  time.Time
}

func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
	if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
		return false
	}
	return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
}

func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
	hdrDefault := func(cfgVal, fallback string) string {
		if strings.TrimSpace(cfgVal) != "" {
			return strings.TrimSpace(cfgVal)
		}
		return fallback
	}

	var hd config.ClaudeHeaderDefaults
	if cfg != nil {
		hd = cfg.ClaudeHeaderDefaults
	}

	profile := claudeDeviceProfile{
		UserAgent:      hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
		PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
		RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
		OS:             hdrDefault(hd.OS, defaultClaudeFingerprintOS),
		Arch:           hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
	}
	if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
		profile.Version = version
		profile.HasVersion = true
	}
	return profile
}

// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
	switch runtime.GOOS {
	case "darwin":
		return "MacOS"
	case "windows":
		return "Windows"
	case "linux":
		return "Linux"
	case "freebsd":
		return "FreeBSD"
	default:
		return "Other::" + runtime.GOOS
	}
}

// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
	switch runtime.GOARCH {
	case "amd64":
		return "x64"
	case "arm64":
		return "arm64"
	case "386":
		return "x86"
	default:
		return "other::" + runtime.GOARCH
	}
}

func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
	matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
	if len(matches) != 4 {
		return claudeCLIVersion{}, false
	}
	major, err := strconv.Atoi(matches[1])
	if err != nil {
		return claudeCLIVersion{}, false
	}
	minor, err := strconv.Atoi(matches[2])
	if err != nil {
		return claudeCLIVersion{}, false
	}
	patch, err := strconv.Atoi(matches[3])
	if err != nil {
		return claudeCLIVersion{}, false
	}
	return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
}

func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
	if candidate.UserAgent == "" || !candidate.HasVersion {
		return false
	}
	if current.UserAgent == "" || !current.HasVersion {
		return true
	}
	return candidate.Version.Compare(current.Version) > 0
}

func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
	profile.OS = baseline.OS
	profile.Arch = baseline.Arch
	return profile
}

// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
// baseline platform and enforces the baseline software fingerprint as a floor.
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
	profile = pinClaudeDeviceProfilePlatform(profile, baseline)
	if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
		profile.UserAgent = baseline.UserAgent
		profile.PackageVersion = baseline.PackageVersion
		profile.RuntimeVersion = baseline.RuntimeVersion
		profile.Version = baseline.Version
		profile.HasVersion = baseline.HasVersion
	}
	return profile
}

func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
	if headers == nil {
		return claudeDeviceProfile{}, false
	}

	userAgent := strings.TrimSpace(headers.Get("User-Agent"))
	version, ok := parseClaudeCLIVersion(userAgent)
	if !ok {
		return claudeDeviceProfile{}, false
	}

	baseline := defaultClaudeDeviceProfile(cfg)
	profile := claudeDeviceProfile{
		UserAgent:      userAgent,
		PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
		RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
		OS:             firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
		Arch:           firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
		Version:        version,
		HasVersion:     true,
	}
	return profile, true
}

func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
	if headers == nil {
		return fallback
	}
	if value := strings.TrimSpace(headers.Get(name)); value != "" {
		return value
	}
	return fallback
}

func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
	switch {
	case auth != nil && strings.TrimSpace(auth.ID) != "":
		return "auth:" + strings.TrimSpace(auth.ID)
	case strings.TrimSpace(apiKey) != "":
		return "api_key:" + strings.TrimSpace(apiKey)
	default:
		return "global"
	}
}

func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
	sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
	return hex.EncodeToString(sum[:])
}

func startClaudeDeviceProfileCacheCleanup() {
	go func() {
		ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
		defer ticker.Stop()
		for range ticker.C {
			purgeExpiredClaudeDeviceProfiles()
		}
	}()
}

func purgeExpiredClaudeDeviceProfiles() {
	now := time.Now()
	claudeDeviceProfileCacheMu.Lock()
	for key, entry := range claudeDeviceProfileCache {
		if !entry.expire.After(now) {
			delete(claudeDeviceProfileCache, key)
		}
	}
	claudeDeviceProfileCacheMu.Unlock()
}

func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
	claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)

	cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
	now := time.Now()
	baseline := defaultClaudeDeviceProfile(cfg)
	candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
	if hasCandidate {
		candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
	}
	if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
		hasCandidate = false
	}

	claudeDeviceProfileCacheMu.RLock()
	entry, hasCached := claudeDeviceProfileCache[cacheKey]
	cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
	claudeDeviceProfileCacheMu.RUnlock()

	if hasCandidate {
		if claudeDeviceProfileBeforeCandidateStore != nil {
			claudeDeviceProfileBeforeCandidateStore(candidate)
		}

		claudeDeviceProfileCacheMu.Lock()
		entry, hasCached = claudeDeviceProfileCache[cacheKey]
		cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
		if cachedValid {
			entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
		}
		if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
			entry.expire = now.Add(claudeDeviceProfileTTL)
			claudeDeviceProfileCache[cacheKey] = entry
			claudeDeviceProfileCacheMu.Unlock()
			return entry.profile
		}

		claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
			profile: candidate,
			expire:  now.Add(claudeDeviceProfileTTL),
		}
		claudeDeviceProfileCacheMu.Unlock()
		return candidate
	}

	if cachedValid {
		claudeDeviceProfileCacheMu.Lock()
		entry = claudeDeviceProfileCache[cacheKey]
		if entry.expire.After(now) && entry.profile.UserAgent != "" {
			entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
			entry.expire = now.Add(claudeDeviceProfileTTL)
			claudeDeviceProfileCache[cacheKey] = entry
			claudeDeviceProfileCacheMu.Unlock()
			return entry.profile
		}
		claudeDeviceProfileCacheMu.Unlock()
	}

	return baseline
}

func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
	if r == nil {
		return
	}
	for _, headerName := range []string{
		"User-Agent",
		"X-Stainless-Package-Version",
		"X-Stainless-Runtime-Version",
		"X-Stainless-Os",
		"X-Stainless-Arch",
	} {
		r.Header.Del(headerName)
	}
	r.Header.Set("User-Agent", profile.UserAgent)
	r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
	r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
	r.Header.Set("X-Stainless-Os", profile.OS)
	r.Header.Set("X-Stainless-Arch", profile.Arch)
}

func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
	if r == nil {
		return
	}
	profile := defaultClaudeDeviceProfile(cfg)
	miscEnsure := func(name, fallback string) {
		if strings.TrimSpace(r.Header.Get(name)) != "" {
			return
		}
		if strings.TrimSpace(ginHeaders.Get(name)) != "" {
			r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
			return
		}
		r.Header.Set(name, fallback)
	}

	miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
	miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
	miscEnsure("X-Stainless-Os", mapStainlessOS())
	miscEnsure("X-Stainless-Arch", mapStainlessArch())

	// Legacy mode preserves per-auth custom header overrides. By the time we get
	// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
	if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
		return
	}

	clientUA := ""
	if ginHeaders != nil {
		clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
	}
	if isClaudeCodeClient(clientUA) {
		r.Header.Set("User-Agent", clientUA)
		return
	}
	r.Header.Set("User-Agent", profile.UserAgent)
}
@@ -14,7 +14,6 @@ import (
 	"io"
 	"net/http"
 	"net/textproto"
-	"runtime"
 	"strings"
 	"time"

@@ -767,36 +766,6 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
 	return body, nil
 }

-// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
-func mapStainlessOS() string {
-	switch runtime.GOOS {
-	case "darwin":
-		return "MacOS"
-	case "windows":
-		return "Windows"
-	case "linux":
-		return "Linux"
-	case "freebsd":
-		return "FreeBSD"
-	default:
-		return "Other::" + runtime.GOOS
-	}
-}
-
-// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
-func mapStainlessArch() string {
-	switch runtime.GOARCH {
-	case "amd64":
-		return "x64"
-	case "arm64":
-		return "arm64"
-	case "386":
-		return "x86"
-	default:
-		return "other::" + runtime.GOARCH
-	}
-}
-
 func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
 	hdrDefault := func(cfgVal, fallback string) string {
 		if cfgVal != "" {
@@ -824,6 +793,11 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
	if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
		ginHeaders = ginCtx.Request.Header
	}
	stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
	var deviceProfile claudeDeviceProfile
	if stabilizeDeviceProfile {
		deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
	}

	baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
	if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
@@ -867,25 +841,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
	misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
	// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
	// For User-Agent, only forward the client's header if it's already a Claude Code client.
	// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
	// to avoid leaking the real client identity during cloaking.
	clientUA := ""
	if ginHeaders != nil {
		clientUA = ginHeaders.Get("User-Agent")
	}
	if isClaudeCodeClient(clientUA) {
		r.Header.Set("User-Agent", clientUA)
	} else {
		r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
	}
	r.Header.Set("Connection", "keep-alive")
	if stream {
		r.Header.Set("Accept", "text/event-stream")
@@ -897,13 +855,19 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
		r.Header.Set("Accept", "application/json")
		r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
	}
	// Keep OS/Arch mapping dynamic (not configurable).
	// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
	// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
	// to the configured baseline while still allowing newer official
	// User-Agent/package/runtime tuples to upgrade the software fingerprint.
	var attrs map[string]string
	if auth != nil {
		attrs = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(r, attrs)
	if stabilizeDeviceProfile {
		applyClaudeDeviceProfileHeaders(r, deviceProfile)
	} else {
		applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
	}
	// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
	// may override it with a user-configured value. Compressed SSE breaks the line
	// scanner regardless of user preference, so this is non-negotiable for streams.

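The stabilized mode described in the comments of this hunk (platform pinned to the configured baseline, software fingerprint allowed to move up to newer official clients) might be approximated as follows. This is only a sketch under stated assumptions: `profile`, `cliVersion`, `newer`, and `upgrade` are hypothetical names, only the User-Agent is tracked, and the real `resolveClaudeDeviceProfile` is not shown in this diff and may work differently.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// profile is a hypothetical, simplified device profile.
type profile struct {
	UserAgent string
	OS, Arch  string
}

// cliVersion extracts X.Y.Z from a "claude-cli/X.Y.Z ..." User-Agent.
// ok is false for any other client.
func cliVersion(ua string) (parts [3]int, ok bool) {
	const prefix = "claude-cli/"
	if !strings.HasPrefix(ua, prefix) {
		return parts, false
	}
	fields := strings.Fields(strings.TrimPrefix(ua, prefix))
	if len(fields) == 0 {
		return parts, false
	}
	nums := strings.Split(fields[0], ".")
	if len(nums) != 3 {
		return parts, false
	}
	for i, n := range nums {
		v, err := strconv.Atoi(n)
		if err != nil {
			return parts, false
		}
		parts[i] = v
	}
	return parts, true
}

// newer reports whether version a is strictly greater than version b.
func newer(a, b [3]int) bool {
	for i := range a {
		if a[i] != b[i] {
			return a[i] > b[i]
		}
	}
	return false
}

// upgrade returns the profile to send next: the User-Agent follows the
// highest official client seen so far, while OS and Arch are never touched.
func upgrade(current profile, clientUA string) profile {
	seen, ok := cliVersion(clientUA)
	if !ok {
		return current // third-party client: keep the current fingerprint
	}
	if have, haveOK := cliVersion(current.UserAgent); haveOK && !newer(seen, have) {
		return current // never downgrade
	}
	current.UserAgent = clientUA
	return current
}

func main() {
	p := profile{UserAgent: "claude-cli/2.1.60 (external, cli)", OS: "MacOS", Arch: "arm64"}
	for _, ua := range []string{
		"claude-cli/2.1.62 (external, cli)",
		"lobe-chat/1.0",
		"claude-cli/2.1.63 (external, cli)",
		"claude-cli/2.1.61 (external, cli)",
	} {
		p = upgrade(p, ua)
		fmt.Println(p.UserAgent, p.OS, p.Arch)
	}
}
```

In this sketch a baseline User-Agent that does not parse as `claude-cli/X.Y.Z` is replaced by the first official client seen, which is one plausible reading of the behaviour the tests in this change set check for custom baselines.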
@@ -8,8 +8,11 @@ import (
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/klauspost/compress/zstd"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -19,6 +22,587 @@ import (
	"github.com/tidwall/sjson"
)

func resetClaudeDeviceProfileCache() {
	claudeDeviceProfileCacheMu.Lock()
	claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
	claudeDeviceProfileCacheMu.Unlock()
}

func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
	t.Helper()

	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginReq := httptest.NewRequest(http.MethodPost, "http://localhost/v1/messages", nil)
	ginReq.Header = incoming.Clone()
	ginCtx.Request = ginReq

	req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
	return req.WithContext(context.WithValue(req.Context(), "gin", ginCtx))
}

func assertClaudeFingerprint(t *testing.T, headers http.Header, userAgent, pkgVersion, runtimeVersion, osName, arch string) {
	t.Helper()

	if got := headers.Get("User-Agent"); got != userAgent {
		t.Fatalf("User-Agent = %q, want %q", got, userAgent)
	}
	if got := headers.Get("X-Stainless-Package-Version"); got != pkgVersion {
		t.Fatalf("X-Stainless-Package-Version = %q, want %q", got, pkgVersion)
	}
	if got := headers.Get("X-Stainless-Runtime-Version"); got != runtimeVersion {
		t.Fatalf("X-Stainless-Runtime-Version = %q, want %q", got, runtimeVersion)
	}
	if got := headers.Get("X-Stainless-Os"); got != osName {
		t.Fatalf("X-Stainless-Os = %q, want %q", got, osName)
	}
	if got := headers.Get("X-Stainless-Arch"); got != arch {
		t.Fatalf("X-Stainless-Arch = %q, want %q", got, arch)
	}
}

func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.70 (external, cli)",
			PackageVersion: "0.80.0",
			RuntimeVersion: "v24.5.0",
			OS: "MacOS",
			Arch: "arm64",
			Timeout: "900",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-baseline",
		Attributes: map[string]string{
			"api_key": "key-baseline",
			"header:User-Agent": "evil-client/9.9",
			"header:X-Stainless-Os": "Linux",
			"header:X-Stainless-Arch": "x64",
			"header:X-Stainless-Package-Version": "9.9.9",
		},
	}
	incoming := http.Header{
		"User-Agent": []string{"curl/8.7.1"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	}

	req := newClaudeHeaderTestRequest(t, incoming)
	applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)

	assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
	if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
		t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
	}
}

func TestApplyClaudeHeaders_TracksHighestClaudeCLIFingerprint(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.60 (external, cli)",
			PackageVersion: "0.70.0",
			RuntimeVersion: "v22.0.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-upgrade",
		Attributes: map[string]string{
			"api_key": "key-upgrade",
		},
	}

	firstReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.74.0"},
		"X-Stainless-Runtime-Version": []string{"v24.3.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(firstReq, auth, "key-upgrade", false, nil, cfg)
	assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")

	thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"lobe-chat/1.0"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Windows"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(thirdPartyReq, auth, "key-upgrade", false, nil, cfg)
	assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")

	higherReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.75.0"},
		"X-Stainless-Runtime-Version": []string{"v24.4.0"},
		"X-Stainless-Os": []string{"MacOS"},
		"X-Stainless-Arch": []string{"arm64"},
	})
	applyClaudeHeaders(higherReq, auth, "key-upgrade", false, nil, cfg)
	assertClaudeFingerprint(t, higherReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")

	lowerReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.73.0"},
		"X-Stainless-Runtime-Version": []string{"v24.2.0"},
		"X-Stainless-Os": []string{"Windows"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(lowerReq, auth, "key-upgrade", false, nil, cfg)
	assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
}

func TestApplyClaudeHeaders_DoesNotDowngradeConfiguredBaselineOnFirstClaudeClient(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.70 (external, cli)",
			PackageVersion: "0.80.0",
			RuntimeVersion: "v24.5.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-baseline-floor",
		Attributes: map[string]string{
			"api_key": "key-baseline-floor",
		},
	}

	olderClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.74.0"},
		"X-Stainless-Runtime-Version": []string{"v24.3.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(olderClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
	assertClaudeFingerprint(t, olderClaudeReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")

	newerClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.81.0"},
		"X-Stainless-Runtime-Version": []string{"v24.6.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(newerClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
	assertClaudeFingerprint(t, newerClaudeReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
}

func TestApplyClaudeHeaders_UpgradesCachedSoftwareFingerprintWhenBaselineAdvances(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	oldCfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.70 (external, cli)",
			PackageVersion: "0.80.0",
			RuntimeVersion: "v24.5.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	newCfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.77 (external, cli)",
			PackageVersion: "0.87.0",
			RuntimeVersion: "v24.8.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-baseline-reload",
		Attributes: map[string]string{
			"api_key": "key-baseline-reload",
		},
	}

	officialReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.81.0"},
		"X-Stainless-Runtime-Version": []string{"v24.6.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(officialReq, auth, "key-baseline-reload", false, nil, oldCfg)
	assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")

	thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"curl/8.7.1"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(thirdPartyReq, auth, "key-baseline-reload", false, nil, newCfg)
	assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}

func TestApplyClaudeHeaders_LearnsOfficialFingerprintAfterCustomBaselineFallback(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "my-gateway/1.0",
			PackageVersion: "custom-pkg",
			RuntimeVersion: "custom-runtime",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-custom-baseline-learning",
		Attributes: map[string]string{
			"api_key": "key-custom-baseline-learning",
		},
	}

	thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"curl/8.7.1"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(thirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
	assertClaudeFingerprint(t, thirdPartyReq.Header, "my-gateway/1.0", "custom-pkg", "custom-runtime", "MacOS", "arm64")

	officialReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.87.0"},
		"X-Stainless-Runtime-Version": []string{"v24.8.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(officialReq, auth, "key-custom-baseline-learning", false, nil, cfg)
	assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")

	postLearningThirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"curl/8.7.1"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(postLearningThirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
	assertClaudeFingerprint(t, postLearningThirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}

func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.60 (external, cli)",
			PackageVersion: "0.70.0",
			RuntimeVersion: "v22.0.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-racy-upgrade",
		Attributes: map[string]string{
			"api_key": "key-racy-upgrade",
		},
	}

	lowPaused := make(chan struct{})
	releaseLow := make(chan struct{})
	var pauseOnce sync.Once
	var releaseOnce sync.Once

	claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
		if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
			return
		}
		pauseOnce.Do(func() { close(lowPaused) })
		<-releaseLow
	}
	t.Cleanup(func() {
		claudeDeviceProfileBeforeCandidateStore = nil
		releaseOnce.Do(func() { close(releaseLow) })
	})

	lowResultCh := make(chan claudeDeviceProfile, 1)
	go func() {
		lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
			"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
			"X-Stainless-Package-Version": []string{"0.74.0"},
			"X-Stainless-Runtime-Version": []string{"v24.3.0"},
			"X-Stainless-Os": []string{"Linux"},
			"X-Stainless-Arch": []string{"x64"},
		}, cfg)
	}()

	select {
	case <-lowPaused:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for lower candidate to pause before storing")
	}

	highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
		"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.75.0"},
		"X-Stainless-Runtime-Version": []string{"v24.4.0"},
		"X-Stainless-Os": []string{"MacOS"},
		"X-Stainless-Arch": []string{"arm64"},
	}, cfg)
	releaseOnce.Do(func() { close(releaseLow) })

	select {
	case lowResult := <-lowResultCh:
		if lowResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
			t.Fatalf("lowResult.UserAgent = %q, want %q", lowResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
		}
		if lowResult.PackageVersion != "0.75.0" {
			t.Fatalf("lowResult.PackageVersion = %q, want %q", lowResult.PackageVersion, "0.75.0")
		}
		if lowResult.OS != "MacOS" || lowResult.Arch != "arm64" {
			t.Fatalf("lowResult platform = %s/%s, want %s/%s", lowResult.OS, lowResult.Arch, "MacOS", "arm64")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for lower candidate result")
	}

	if highResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
		t.Fatalf("highResult.UserAgent = %q, want %q", highResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
	}
	if highResult.OS != "MacOS" || highResult.Arch != "arm64" {
		t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
	}

	cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
		"User-Agent": []string{"curl/8.7.1"},
	}, cfg)
	if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
		t.Fatalf("cached.UserAgent = %q, want %q", cached.UserAgent, "claude-cli/2.1.63 (external, cli)")
	}
	if cached.PackageVersion != "0.75.0" {
		t.Fatalf("cached.PackageVersion = %q, want %q", cached.PackageVersion, "0.75.0")
	}
	if cached.OS != "MacOS" || cached.Arch != "arm64" {
		t.Fatalf("cached platform = %s/%s, want %s/%s", cached.OS, cached.Arch, "MacOS", "arm64")
	}
}

func TestApplyClaudeHeaders_ThirdPartyBaselineThenOfficialUpgradeKeepsPinnedPlatform(t *testing.T) {
	resetClaudeDeviceProfileCache()
	stabilize := true

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.70 (external, cli)",
			PackageVersion: "0.80.0",
			RuntimeVersion: "v24.5.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-third-party-then-official",
		Attributes: map[string]string{
			"api_key": "key-third-party-then-official",
		},
	}

	thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"curl/8.7.1"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(thirdPartyReq, auth, "key-third-party-then-official", false, nil, cfg)
	assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")

	officialReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.87.0"},
		"X-Stainless-Runtime-Version": []string{"v24.8.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(officialReq, auth, "key-third-party-then-official", false, nil, cfg)
	assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}

func TestApplyClaudeHeaders_DisableDeviceProfileStabilization(t *testing.T) {
	resetClaudeDeviceProfileCache()

	stabilize := false
	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.60 (external, cli)",
			PackageVersion: "0.70.0",
			RuntimeVersion: "v22.0.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-disable-stability",
		Attributes: map[string]string{
			"api_key": "key-disable-stability",
		},
	}

	firstReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.74.0"},
		"X-Stainless-Runtime-Version": []string{"v24.3.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(firstReq, auth, "key-disable-stability", false, nil, cfg)
	assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "Linux", "x64")

	thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"lobe-chat/1.0"},
		"X-Stainless-Package-Version": []string{"0.10.0"},
		"X-Stainless-Runtime-Version": []string{"v18.0.0"},
		"X-Stainless-Os": []string{"Windows"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(thirdPartyReq, auth, "key-disable-stability", false, nil, cfg)
	assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.60 (external, cli)", "0.10.0", "v18.0.0", "Windows", "x64")

	lowerReq := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.73.0"},
		"X-Stainless-Runtime-Version": []string{"v24.2.0"},
		"X-Stainless-Os": []string{"Windows"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(lowerReq, auth, "key-disable-stability", false, nil, cfg)
	assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.61 (external, cli)", "0.73.0", "v24.2.0", "Windows", "x64")
}

func TestApplyClaudeHeaders_LegacyModePreservesConfiguredUserAgentOverrideForClaudeClients(t *testing.T) {
	resetClaudeDeviceProfileCache()

	stabilize := false
	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.60 (external, cli)",
			PackageVersion: "0.70.0",
			RuntimeVersion: "v22.0.0",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-legacy-ua-override",
		Attributes: map[string]string{
			"api_key": "key-legacy-ua-override",
			"header:User-Agent": "config-ua/1.0",
		},
	}

	req := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
		"X-Stainless-Package-Version": []string{"0.74.0"},
		"X-Stainless-Runtime-Version": []string{"v24.3.0"},
		"X-Stainless-Os": []string{"Linux"},
		"X-Stainless-Arch": []string{"x64"},
	})
	applyClaudeHeaders(req, auth, "key-legacy-ua-override", false, nil, cfg)

	assertClaudeFingerprint(t, req.Header, "config-ua/1.0", "0.74.0", "v24.3.0", "Linux", "x64")
}

func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *testing.T) {
	resetClaudeDeviceProfileCache()

	stabilize := false
	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.60 (external, cli)",
			PackageVersion: "0.70.0",
			RuntimeVersion: "v22.0.0",
			OS: "MacOS",
			Arch: "arm64",
			StabilizeDeviceProfile: &stabilize,
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-legacy-runtime-os-arch",
		Attributes: map[string]string{
			"api_key": "key-legacy-runtime-os-arch",
		},
	}

	req := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"curl/8.7.1"},
	})
	applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)

	assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
}

func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
	resetClaudeDeviceProfileCache()

	cfg := &config.Config{
		ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
			UserAgent: "claude-cli/2.1.60 (external, cli)",
			PackageVersion: "0.70.0",
			RuntimeVersion: "v22.0.0",
			OS: "MacOS",
			Arch: "arm64",
		},
	}
	auth := &cliproxyauth.Auth{
		ID: "auth-unset-runtime-os-arch",
		Attributes: map[string]string{
			"api_key": "key-unset-runtime-os-arch",
		},
	}

	req := newClaudeHeaderTestRequest(t, http.Header{
		"User-Agent": []string{"curl/8.7.1"},
	})
	applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)

	assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
}

func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
	if claudeDeviceProfileStabilizationEnabled(nil) {
		t.Fatal("expected nil config to default to disabled stabilization")
	}
	if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
		t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
	}
}

func TestApplyClaudeToolPrefix(t *testing.T) {
	input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")

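`TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate` above pauses a lower-version candidate just before it is stored, lets a higher version win, and then expects the lower candidate to observe the higher value. One common way to get that property is to compare and store under the same lock. The following is a simplified sketch of that idea, not the repository's implementation: `versionCache` and `offer` are hypothetical names, and versions are reduced to plain integers.

```go
package main

import (
	"fmt"
	"sync"
)

// versionCache is a hypothetical cache that keeps the highest version per key.
type versionCache struct {
	mu      sync.Mutex
	highest map[string]int
}

// offer proposes a candidate and returns the value that ends up cached.
// The comparison happens under the lock, against the value cached right now,
// so a slow writer cannot overwrite a newer entry with a stale candidate.
func (c *versionCache) offer(key string, candidate int) int {
	c.mu.Lock()
	defer c.mu.Unlock()
	if cur, ok := c.highest[key]; ok && cur >= candidate {
		return cur
	}
	c.highest[key] = candidate
	return candidate
}

func main() {
	c := &versionCache{highest: make(map[string]int)}
	var wg sync.WaitGroup
	for _, v := range []int{62, 63, 61} {
		wg.Add(1)
		go func(v int) {
			defer wg.Done()
			c.offer("auth-1", v)
		}(v)
	}
	wg.Wait()
	// Whatever order the goroutines ran in, the cache now holds 63.
	fmt.Println(c.offer("auth-1", 0))
}
```

Returning the cached value from `offer` mirrors what the test asserts: the caller that lost the race still receives the winning profile.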
@@ -73,17 +73,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
		return
	}
	r.once.Do(func() {
		usage.PublishRecord(ctx, usage.Record{
			Provider: r.provider,
			Model: r.model,
			Source: r.source,
			APIKey: r.apiKey,
			AuthID: r.authID,
			AuthIndex: r.authIndex,
			RequestedAt: r.requestedAt,
			Failed: failed,
			Detail: detail,
		})
		usage.PublishRecord(ctx, r.buildRecord(detail, failed))
	})
}

@@ -96,20 +86,39 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
		return
	}
	r.once.Do(func() {
		usage.PublishRecord(ctx, usage.Record{
			Provider: r.provider,
			Model: r.model,
			Source: r.source,
			APIKey: r.apiKey,
			AuthID: r.authID,
			AuthIndex: r.authIndex,
			RequestedAt: r.requestedAt,
			Failed: false,
			Detail: usage.Detail{},
		})
		usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
	})
}

func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
	if r == nil {
		return usage.Record{Detail: detail, Failed: failed}
	}
	return usage.Record{
		Provider: r.provider,
		Model: r.model,
		Source: r.source,
		APIKey: r.apiKey,
		AuthID: r.authID,
		AuthIndex: r.authIndex,
		RequestedAt: r.requestedAt,
		Latency: r.latency(),
		Failed: failed,
		Detail: detail,
	}
}

func (r *usageReporter) latency() time.Duration {
	if r == nil || r.requestedAt.IsZero() {
		return 0
	}
	latency := time.Since(r.requestedAt)
	if latency < 0 {
		return 0
	}
	return latency
}

func apiKeyFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""

@@ -1,6 +1,11 @@
package executor

import "testing"
import (
	"testing"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)

func TestParseOpenAIUsageChatCompletions(t *testing.T) {
	data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
@@ -41,3 +46,19 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
		t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
	}
}

func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
	reporter := &usageReporter{
		provider: "openai",
		model: "gpt-5.4",
		requestedAt: time.Now().Add(-1500 * time.Millisecond),
	}

	record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
	if record.Latency < time.Second {
		t.Fatalf("latency = %v, want >= 1s", record.Latency)
	}
	if record.Latency > 3*time.Second {
		t.Fatalf("latency = %v, want <= 3s", record.Latency)
	}
}

@@ -104,59 +104,59 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Always try cached signature first (more reliable than client-provided)
|
||||
// Client may send stale or invalid signatures from different sessions
|
||||
signature := ""
|
||||
if thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
// log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
signature := ""
|
||||
if thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
// log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to client signature only if cache miss and client signature is valid
|
||||
if signature == "" {
|
||||
signatureResult := contentResult.Get("signature")
|
||||
clientSignature := ""
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||
if len(arrayClientSignatures) == 2 {
|
||||
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
|
||||
clientSignature = arrayClientSignatures[1]
|
||||
}
|
||||
// Fallback to client signature only if cache miss and client signature is valid
|
||||
if signature == "" {
|
||||
signatureResult := contentResult.Get("signature")
|
||||
clientSignature := ""
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||
if len(arrayClientSignatures) == 2 {
|
||||
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
|
||||
clientSignature = arrayClientSignatures[1]
|
||||
}
|
||||
}
|
||||
if cache.HasValidSignature(modelName, clientSignature) {
|
||||
signature = clientSignature
|
||||
}
|
||||
// log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
if cache.HasValidSignature(modelName, clientSignature) {
|
||||
signature = clientSignature
|
||||
}
|
||||
// log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
|
||||
// Store for subsequent tool_use in the same message
|
||||
if cache.HasValidSignature(modelName, signature) {
|
||||
currentMessageThinkingSignature = signature
|
||||
}
|
||||
// Store for subsequent tool_use in the same message
|
||||
if cache.HasValidSignature(modelName, signature) {
|
||||
currentMessageThinkingSignature = signature
|
||||
}
|
||||
|
||||
// Skip trailing unsigned thinking blocks on last assistant message
|
||||
isUnsigned := !cache.HasValidSignature(modelName, signature)
|
||||
// Skip trailing unsigned thinking blocks on last assistant message
|
||||
isUnsigned := !cache.HasValidSignature(modelName, signature)
|
||||
|
||||
// If unsigned, skip entirely (don't convert to text)
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
// Converting to text would break this requirement
|
||||
if isUnsigned {
|
||||
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
enableThoughtTranslate = false
|
||||
continue
|
||||
}
|
||||
// If unsigned, skip entirely (don't convert to text)
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
// Converting to text would break this requirement
|
||||
if isUnsigned {
|
||||
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
enableThoughtTranslate = false
|
||||
continue
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||
if thinkingText != "" {
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
// Valid signature, send as thought block
|
||||
// Always include "text" field — Google Antigravity API requires it
|
||||
// even for redacted thinking where the text is empty.
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
// Skip empty text parts to avoid Gemini API error:
|
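Review note: the fallback path in this hunk splits the client-supplied signature on the first `#` and uses the remainder only when the prefix matches the model group. The sketch below isolates that parsing step. The helper names `clientSignatureFor` and `pickSignature` and the group strings are assumptions for illustration, and the `cache.HasValidSignature` check from the hunk is deliberately left out.

```go
package main

import (
	"fmt"
	"strings"
)

// clientSignatureFor is a hypothetical helper. It expects raw to look like
// "<group>#<signature>" and returns the signature only when the group matches.
func clientSignatureFor(raw, modelGroup string) string {
	// SplitN with n=2 keeps any further '#' characters inside the signature.
	parts := strings.SplitN(raw, "#", 2)
	if len(parts) == 2 && parts[0] == modelGroup {
		return parts[1]
	}
	return ""
}

// pickSignature prefers the cached signature and falls back to the client one.
func pickSignature(cached, rawClient, modelGroup string) string {
	if cached != "" {
		return cached
	}
	return clientSignatureFor(rawClient, modelGroup)
}

func main() {
	fmt.Println(pickSignature("", "claude#abc#def", "claude")) // abc#def
	fmt.Println(pickSignature("", "gemini#abc", "claude") == "")
	fmt.Println(pickSignature("cached", "claude#abc", "claude")) // cached
}
```

In the real hunk the chosen value still has to pass a validity check before it is used, so an empty or mismatched result simply leads to the thinking block being dropped.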
||||
@@ -171,7 +171,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
|
||||
@@ -233,7 +233,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
functionResponseJSON := []byte(`{}`)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", funcName)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
@@ -398,6 +398,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
for toolKey := range gjson.ParseBytes(tool).Map() {
|
||||
if util.InArray(allowedToolKeys, toolKey) {
|
||||
continue
|
||||
@@ -471,7 +472,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
case "tool":
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,10 @@ type Params struct {
|
||||
|
||||
// Signature caching support
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
|
||||
// Reverse map: sanitized Gemini function name → original Claude tool name.
|
||||
// Populated lazily on the first response chunk from the original request JSON.
|
||||
ToolNameMap map[string]string
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -71,6 +75,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
@@ -212,7 +217,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude Code API compatibility
|
||||
params.HasToolUse = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
@@ -348,7 +353,7 @@ func resolveStopReason(params *Params) string {
|
||||
// Returns:
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
_ = originalRequestRawJSON
|
||||
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -450,7 +455,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
|
||||
toolIDCounter++
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
|
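Review note: the hunks above sanitize tool names on the way to the upstream and restore them on the way back through a reverse map. The `util` implementations are not part of this diff, so the following is only a sketch of the round trip under assumptions: `sanitizeName`, `buildNameMap`, and `restoreName` are hypothetical stand-ins, and the character rule (keep ASCII letters, digits, `_`, `.`, `-`, replace everything else with `_`) is a guess rather than the project's actual rule.

```go
package main

import (
	"fmt"
	"strings"
)

// sanitizeName is a hypothetical stand-in for util.SanitizeFunctionName.
// Assumed rule: keep ASCII letters, digits, '_', '.', '-'; replace the rest with '_'.
func sanitizeName(name string) string {
	var b strings.Builder
	for _, r := range name {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9',
			r == '_', r == '.', r == '-':
			b.WriteRune(r)
		default:
			b.WriteByte('_')
		}
	}
	return b.String()
}

// buildNameMap records sanitized -> original for every name that changed.
func buildNameMap(originals []string) map[string]string {
	m := make(map[string]string)
	for _, orig := range originals {
		if s := sanitizeName(orig); s != orig {
			m[s] = orig
		}
	}
	return m
}

// restoreName returns the original name when known, else the input unchanged.
// Looking up a key in a nil map is safe in Go and simply misses.
func restoreName(m map[string]string, name string) string {
	if orig, ok := m[name]; ok {
		return orig
	}
	return name
}

func main() {
	names := buildNameMap([]string{"mcp/search:web", "read_file"})
	fmt.Println(sanitizeName("mcp/search:web"))       // mcp_search_web
	fmt.Println(restoreName(names, "mcp_search_web")) // mcp/search:web
	fmt.Println(restoreName(names, "read_file"))      // read_file
}
```

One limitation of this sketch: two different original names that sanitize to the same string overwrite each other in the map, so only one of them can be restored.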
||||
@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
@@ -309,7 +309,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -384,7 +384,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
fnRawBytes := []byte(fnRaw)
|
||||
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw, _ = sjson.Delete(string(fnRawBytes), "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
@@ -26,6 +27,7 @@ type convertCliResponseToOpenAIChatParams struct {
|
||||
FunctionIndex int
|
||||
SawToolCall bool // Tracks if any tool call was seen in the entire stream
|
||||
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -48,10 +50,14 @@ var functionCallIDCounter uint64
|
||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return [][]byte{}
|
||||
@@ -159,7 +165,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
}
|
||||
|
||||
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
|
||||
@@ -36,6 +36,18 @@ type ToolCallAccumulator struct {
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
completionTokens = usage.Get("output_tokens").Int()
|
||||
cachedTokens = usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
|
||||
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
|
||||
totalTokens = promptTokens + completionTokens
|
||||
|
||||
return promptTokens, completionTokens, totalTokens, cachedTokens
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format.
|
||||
// This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses.
|
||||
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
|
||||
@@ -207,14 +219,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
|
||||
// Handle usage information for token counts
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
return [][]byte{template}
|
||||
|
||||
@@ -366,14 +375,11 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
out := ConvertClaudeResponseToOpenAI(
|
||||
ctx,
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
nil,
|
||||
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":13,"output_tokens":4,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}`),
|
||||
&param,
|
||||
)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
|
||||
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
|
||||
|
||||
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package responses
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -39,6 +40,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
rawJSON = normalizeCodexBuiltinTools(rawJSON)
|
||||
|
||||
return rawJSON
|
||||
}
|
||||
@@ -82,3 +84,59 @@ func convertSystemRoleToDeveloper(rawJSON []byte) []byte {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// normalizeCodexBuiltinTools rewrites legacy/preview built-in tool variants to the
|
||||
// stable names expected by the current Codex upstream.
|
||||
func normalizeCodexBuiltinTools(rawJSON []byte) []byte {
|
||||
result := rawJSON
|
||||
|
||||
tools := gjson.GetBytes(result, "tools")
|
||||
if tools.IsArray() {
|
||||
toolArray := tools.Array()
|
||||
for i := 0; i < len(toolArray); i++ {
|
||||
typePath := fmt.Sprintf("tools.%d.type", i)
|
||||
result = normalizeCodexBuiltinToolAtPath(result, typePath)
|
||||
}
|
||||
}
|
||||
|
||||
result = normalizeCodexBuiltinToolAtPath(result, "tool_choice.type")
|
||||
|
||||
toolChoiceTools := gjson.GetBytes(result, "tool_choice.tools")
|
||||
if toolChoiceTools.IsArray() {
|
||||
toolArray := toolChoiceTools.Array()
|
||||
for i := 0; i < len(toolArray); i++ {
|
||||
typePath := fmt.Sprintf("tool_choice.tools.%d.type", i)
|
||||
result = normalizeCodexBuiltinToolAtPath(result, typePath)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeCodexBuiltinToolAtPath(rawJSON []byte, path string) []byte {
|
||||
currentType := gjson.GetBytes(rawJSON, path).String()
|
||||
normalizedType := normalizeCodexBuiltinToolType(currentType)
|
||||
if normalizedType == "" {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
updated, err := sjson.SetBytes(rawJSON, path, normalizedType)
|
||||
if err != nil {
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
log.Debugf("codex responses: normalized builtin tool type at %s from %q to %q", path, currentType, normalizedType)
|
||||
return updated
|
||||
}
|
||||
|
||||
// normalizeCodexBuiltinToolType centralizes the current known Codex Responses
|
||||
// built-in tool alias compatibility. If Codex introduces more legacy aliases,
|
||||
// extend this helper instead of adding path-specific rewrite logic elsewhere.
|
||||
func normalizeCodexBuiltinToolType(toolType string) string {
|
||||
switch toolType {
|
||||
case "web_search_preview", "web_search_preview_2025_03_11":
|
||||
return "web_search"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
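Review note: the normalization above touches three places, `tools[*].type`, `tool_choice.type`, and `tool_choice.tools[*].type`. The sketch below shows the same traversal with only the standard library. It is an illustration under assumptions, not the project's code: `normalizeRequest`, `rewriteType`, `rewriteTools`, and `stableToolType` are hypothetical names, and it decodes into maps instead of using path-based `sjson` edits.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// stableToolType mirrors the alias table in the diff.
func stableToolType(t string) (string, bool) {
	switch t {
	case "web_search_preview", "web_search_preview_2025_03_11":
		return "web_search", true
	}
	return "", false
}

// rewriteType updates obj["type"] in place when it holds a known alias.
func rewriteType(obj map[string]any) {
	if t, ok := obj["type"].(string); ok {
		if stable, known := stableToolType(t); known {
			obj["type"] = stable
		}
	}
}

// rewriteTools applies rewriteType to each object of a decoded JSON array.
// A missing or non-array value is ignored.
func rewriteTools(v any) {
	items, _ := v.([]any)
	for _, item := range items {
		if obj, ok := item.(map[string]any); ok {
			rewriteType(obj)
		}
	}
}

// normalizeRequest is a hypothetical stdlib-only analogue of
// normalizeCodexBuiltinTools.
func normalizeRequest(raw []byte) ([]byte, error) {
	var req map[string]any
	if err := json.Unmarshal(raw, &req); err != nil {
		return nil, err
	}
	rewriteTools(req["tools"])
	if choice, ok := req["tool_choice"].(map[string]any); ok {
		rewriteType(choice)
		rewriteTools(choice["tools"])
	}
	return json.Marshal(req)
}

func main() {
	out, err := normalizeRequest([]byte(`{"tools":[{"type":"web_search_preview"}],"tool_choice":{"type":"allowed_tools"}}`))
	fmt.Println(string(out), err)
}
```

One practical difference from the path-based version: decoding into a map and re-encoding sorts object keys and re-serializes numbers, whereas in-place path edits leave the rest of the payload byte-for-byte as it was.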
||||
@@ -264,6 +264,52 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIResponsesRequestToCodex_NormalizesWebSearchPreview(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.4-mini",
|
||||
"input": "find latest OpenAI model news",
|
||||
"tools": [
|
||||
{"type": "web_search_preview_2025_03_11"}
|
||||
],
|
||||
"tool_choice": {
|
||||
"type": "allowed_tools",
|
||||
"tools": [
|
||||
{"type": "web_search_preview"},
|
||||
{"type": "web_search_preview_2025_03_11"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "allowed_tools" {
|
||||
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "allowed_tools", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.tools.0.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
if got := gjson.GetBytes(output, "tool_choice.tools.1.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.tools.1.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIResponsesRequestToCodex_NormalizesTopLevelToolChoicePreviewAlias(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.4-mini",
|
||||
"input": "find latest OpenAI model news",
|
||||
"tool_choice": {"type": "web_search_preview_2025_03_11"}
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "web_search", string(output))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
|
||||
@@ -89,7 +89,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
case "tool_use":
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
@@ -112,7 +112,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", util.SanitizeFunctionName(funcName))
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
@@ -151,6 +151,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
tool, _ = sjson.DeleteBytes(tool, "strict")
|
||||
tool, _ = sjson.DeleteBytes(tool, "input_examples")
|
||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||
@@ -194,7 +195,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
case "tool":
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,9 @@ type Params struct {
|
||||
ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function
|
||||
ResponseIndex int // Index counter for content blocks in the streaming response
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
|
||||
// Reverse map: sanitized Gemini function name → original Claude tool name.
|
||||
ToolNameMap map[string]string
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -55,6 +58,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +169,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude Code API compatibility
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*Params).ToolNameMap, functionCallResult.Get("name").String())
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
@@ -248,7 +252,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Returns:
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
_ = originalRequestRawJSON
|
||||
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -306,7 +310,7 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
|
||||
toolIDCounter++
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
|
||||
@@ -251,7 +251,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
@@ -268,7 +268,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -331,6 +331,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
continue
|
||||
}
|
||||
}
|
||||
fnRaw, _ = sjson.SetBytes(fnRaw, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw, _ = sjson.DeleteBytes(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -21,8 +22,9 @@ import (
|
||||
|
||||
// convertCliResponseToOpenAIChatParams holds parameters for response conversion.
|
||||
type convertCliResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -45,10 +47,14 @@ var functionCallIDCounter uint64
|
||||
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return [][]byte{}
|
||||
@@ -163,7 +169,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
|
||||
functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -90,6 +91,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
functionName = derived
|
||||
}
|
||||
}
|
||||
functionName = util.SanitizeFunctionName(functionName)
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
@@ -109,6 +111,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
funcName = util.SanitizeFunctionName(funcName)
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
|
||||
@@ -165,6 +168,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[{"functionDeclarations":[]}]`))
|
||||
@@ -202,7 +206,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
case "tool":
|
||||
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type Params struct {
|
||||
ResponseIndex int
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
ToolNameMap map[string]string
|
||||
SanitizedNameMap map[string]string
|
||||
SawToolCall bool
|
||||
}
|
||||
|
||||
@@ -57,6 +58,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
SawToolCall: false,
|
||||
}
|
||||
}
|
||||
@@ -167,6 +169,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
// This processes tool usage requests and formats them for Claude API compatibility
|
||||
(*param).(*Params).SawToolCall = true
|
||||
upstreamToolName := functionCallResult.Get("name").String()
|
||||
upstreamToolName = util.RestoreSanitizedToolName((*param).(*Params).SanitizedNameMap, upstreamToolName)
|
||||
clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName)
|
||||
|
||||
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||
@@ -260,6 +263,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
sanitizedNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
|
||||
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
out, _ = sjson.SetBytes(out, "id", root.Get("responseId").String())
|
||||
@@ -315,6 +319,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
hasToolCall = true
|
||||
|
||||
upstreamToolName := functionCall.Get("name").String()
|
||||
upstreamToolName = util.RestoreSanitizedToolName(sanitizedNameMap, upstreamToolName)
|
||||
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
||||
toolIDCounter++
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
|
||||
@@ -257,7 +257,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
@@ -274,7 +274,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
@@ -341,6 +341,9 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
}
|
||||
fnRawBytes := []byte(fnRaw)
|
||||
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||
fnRaw = string(fnRawBytes)
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -22,7 +23,8 @@ import (
|
||||
type convertGeminiResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
// FunctionIndex tracks tool call indices per candidate index to support multiple candidates.
|
||||
FunctionIndex map[int]int
|
||||
FunctionIndex map[int]int
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -46,8 +48,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
// Initialize parameters if nil.
|
||||
if *param == nil {
|
||||
*param = &convertGeminiResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +59,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if p.FunctionIndex == nil {
|
||||
p.FunctionIndex = make(map[int]int)
|
||||
}
|
||||
if p.SanitizedNameMap == nil {
|
||||
p.SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
@@ -191,7 +197,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
|
||||
functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(p.SanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
@@ -265,6 +271,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
// Returns:
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
sanitizedNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
var unixTimestamp int64
|
||||
// Initialize template with an empty choices array to support multiple candidates.
|
||||
template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[]}`)
|
||||
@@ -358,7 +365,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
choiceTemplate, _ = sjson.SetRawBytes(choiceTemplate, "message.tool_calls", []byte(`[]`))
|
||||
}
|
||||
functionCallItemTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
fcName := util.RestoreSanitizedToolName(sanitizedNameMap, functionCallResult.Get("name").String())
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -291,7 +292,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
|
||||
case "function_call":
|
||||
// Handle function calls - convert to model message with functionCall
|
||||
name := item.Get("name").String()
|
||||
name := util.SanitizeFunctionName(item.Get("name").String())
|
||||
arguments := item.Get("arguments").String()
|
||||
|
||||
modelContent := []byte(`{"role":"model","parts":[]}`)
|
||||
@@ -333,6 +334,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
return true
|
||||
})
|
||||
}
|
||||
functionName = util.SanitizeFunctionName(functionName)
|
||||
|
||||
functionResponse, _ = sjson.SetBytes(functionResponse, "functionResponse.name", functionName)
|
||||
functionResponse, _ = sjson.SetBytes(functionResponse, "functionResponse.id", callID)
|
||||
@@ -375,7 +377,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
funcDecl := []byte(`{"name":"","description":"","parametersJsonSchema":{}}`)
|
||||
|
||||
if name := tool.Get("name"); name.Exists() {
|
||||
funcDecl, _ = sjson.SetBytes(funcDecl, "name", name.String())
|
||||
funcDecl, _ = sjson.SetBytes(funcDecl, "name", util.SanitizeFunctionName(name.String()))
|
||||
}
|
||||
if desc := tool.Get("description"); desc.Exists() {
|
||||
funcDecl, _ = sjson.SetBytes(funcDecl, "description", desc.String())
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -36,11 +37,12 @@ type geminiToResponsesState struct {
|
||||
ReasoningClosed bool
|
||||
|
||||
// function call aggregation (keyed by output_index)
|
||||
NextIndex int
|
||||
FuncArgsBuf map[int]*strings.Builder
|
||||
FuncNames map[int]string
|
||||
FuncCallIDs map[int]string
|
||||
FuncDone map[int]bool
|
||||
NextIndex int
|
||||
FuncArgsBuf map[int]*strings.Builder
|
||||
FuncNames map[int]string
|
||||
FuncCallIDs map[int]string
|
||||
FuncDone map[int]bool
|
||||
SanitizedNameMap map[string]string
|
||||
}
|
||||
|
||||
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||
@@ -90,10 +92,11 @@ func emitEvent(event string, payload []byte) []byte {
|
||||
func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &geminiToResponsesState{
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncDone: make(map[int]bool),
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncDone: make(map[int]bool),
|
||||
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
st := (*param).(*geminiToResponsesState)
|
||||
@@ -109,6 +112,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
if st.FuncDone == nil {
|
||||
st.FuncDone = make(map[int]bool)
|
||||
}
|
||||
if st.SanitizedNameMap == nil {
|
||||
st.SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
@@ -306,7 +312,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// Responses streaming requires message done events before the next output_item.added.
|
||||
finalizeReasoning()
|
||||
finalizeMessage()
|
||||
name := fc.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(st.SanitizedNameMap, fc.Get("name").String())
|
||||
idx := st.NextIndex
|
||||
st.NextIndex++
|
||||
// Ensure buffers
|
||||
@@ -565,6 +571,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
root = unwrapGeminiResponseRoot(root)
|
||||
sanitizedNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||
|
||||
// Base response scaffold
|
||||
resp := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`)
|
||||
@@ -694,7 +701,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
return true
|
||||
}
|
||||
if fc := p.Get("functionCall"); fc.Exists() {
|
||||
name := fc.Get("name").String()
|
||||
name := util.RestoreSanitizedToolName(sanitizedNameMap, fc.Get("name").String())
|
||||
args := fc.Get("args")
|
||||
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
|
||||
itemJSON := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
|
||||
@@ -87,9 +87,10 @@ type modelStats struct {
|
||||
Details []RequestDetail
|
||||
}
|
||||
|
||||
// RequestDetail stores the timestamp and token usage for a single request.
|
||||
// RequestDetail stores the timestamp, latency, and token usage for a single request.
|
||||
type RequestDetail struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
LatencyMs int64 `json:"latency_ms"`
|
||||
Source string `json:"source"`
|
||||
AuthIndex string `json:"auth_index"`
|
||||
Tokens TokenStats `json:"tokens"`
|
||||
@@ -198,6 +199,7 @@ func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record)
|
||||
}
|
||||
s.updateAPIStats(stats, modelName, RequestDetail{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: normaliseLatency(record.Latency),
|
||||
Source: record.Source,
|
||||
AuthIndex: record.AuthIndex,
|
||||
Tokens: detail,
|
||||
@@ -332,6 +334,9 @@ func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResu
|
||||
}
|
||||
for _, detail := range modelSnapshot.Details {
|
||||
detail.Tokens = normaliseTokenStats(detail.Tokens)
|
||||
if detail.LatencyMs < 0 {
|
||||
detail.LatencyMs = 0
|
||||
}
|
||||
if detail.Timestamp.IsZero() {
|
||||
detail.Timestamp = time.Now()
|
||||
}
|
||||
@@ -463,6 +468,13 @@ func normaliseTokenStats(tokens TokenStats) TokenStats {
|
||||
return tokens
|
||||
}
|
||||
|
||||
func normaliseLatency(latency time.Duration) int64 {
	if latency <= 0 {
		return 0
	}
	return latency.Milliseconds()
}
|
||||
|
||||
func formatHour(hour int) string {
|
||||
if hour < 0 {
|
||||
hour = 0
|
||||
|
||||
96
internal/usage/logger_plugin_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestRequestStatisticsRecordIncludesLatency(t *testing.T) {
|
||||
stats := NewRequestStatistics()
|
||||
stats.Record(context.Background(), coreusage.Record{
|
||||
APIKey: "test-key",
|
||||
Model: "gpt-5.4",
|
||||
RequestedAt: time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC),
|
||||
Latency: 1500 * time.Millisecond,
|
||||
Detail: coreusage.Detail{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
})
|
||||
|
||||
snapshot := stats.Snapshot()
|
||||
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("details len = %d, want 1", len(details))
|
||||
}
|
||||
if details[0].LatencyMs != 1500 {
|
||||
t.Fatalf("latency_ms = %d, want 1500", details[0].LatencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestStatisticsMergeSnapshotDedupIgnoresLatency(t *testing.T) {
|
||||
stats := NewRequestStatistics()
|
||||
timestamp := time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC)
|
||||
first := StatisticsSnapshot{
|
||||
APIs: map[string]APISnapshot{
|
||||
"test-key": {
|
||||
Models: map[string]ModelSnapshot{
|
||||
"gpt-5.4": {
|
||||
Details: []RequestDetail{{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: 0,
|
||||
Source: "user@example.com",
|
||||
AuthIndex: "0",
|
||||
Tokens: TokenStats{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
second := StatisticsSnapshot{
|
||||
APIs: map[string]APISnapshot{
|
||||
"test-key": {
|
||||
Models: map[string]ModelSnapshot{
|
||||
"gpt-5.4": {
|
||||
Details: []RequestDetail{{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: 2500,
|
||||
Source: "user@example.com",
|
||||
AuthIndex: "0",
|
||||
Tokens: TokenStats{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := stats.MergeSnapshot(first)
|
||||
if result.Added != 1 || result.Skipped != 0 {
|
||||
t.Fatalf("first merge = %+v, want added=1 skipped=0", result)
|
||||
}
|
||||
|
||||
result = stats.MergeSnapshot(second)
|
||||
if result.Added != 0 || result.Skipped != 1 {
|
||||
t.Fatalf("second merge = %+v, want added=0 skipped=1", result)
|
||||
}
|
||||
|
||||
snapshot := stats.Snapshot()
|
||||
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("details len = %d, want 1", len(details))
|
||||
}
|
||||
}
|
||||
@@ -54,3 +54,77 @@ func TestSanitizeFunctionName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizedToolNameMap(t *testing.T) {
|
||||
t.Run("returns map for tools needing sanitization", func(t *testing.T) {
|
||||
raw := []byte(`{"tools":[
|
||||
{"name":"valid_tool","input_schema":{}},
|
||||
{"name":"mcp/server/read","input_schema":{}},
|
||||
{"name":"tool@v2","input_schema":{}}
|
||||
]}`)
|
||||
m := SanitizedToolNameMap(raw)
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil map")
|
||||
}
|
||||
if m["mcp_server_read"] != "mcp/server/read" {
|
||||
t.Errorf("expected mcp_server_read → mcp/server/read, got %q", m["mcp_server_read"])
|
||||
}
|
||||
if m["tool_v2"] != "tool@v2" {
|
||||
t.Errorf("expected tool_v2 → tool@v2, got %q", m["tool_v2"])
|
||||
}
|
||||
if _, exists := m["valid_tool"]; exists {
|
||||
t.Error("valid_tool should not be in the map (no sanitization needed)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when no tools need sanitization", func(t *testing.T) {
|
||||
raw := []byte(`{"tools":[{"name":"Read","input_schema":{}},{"name":"Write","input_schema":{}}]}`)
|
||||
m := SanitizedToolNameMap(raw)
|
||||
if m != nil {
|
||||
t.Errorf("expected nil, got %v", m)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil for empty/missing tools", func(t *testing.T) {
|
||||
if m := SanitizedToolNameMap([]byte(`{}`)); m != nil {
|
||||
t.Error("expected nil for no tools")
|
||||
}
|
||||
if m := SanitizedToolNameMap(nil); m != nil {
|
||||
t.Error("expected nil for nil input")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("collision keeps first mapping", func(t *testing.T) {
|
||||
raw := []byte(`{"tools":[
|
||||
{"name":"read/file","input_schema":{}},
|
||||
{"name":"read@file","input_schema":{}}
|
||||
]}`)
|
||||
m := SanitizedToolNameMap(raw)
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil map")
|
||||
}
|
||||
if m["read_file"] != "read/file" {
|
||||
t.Errorf("expected first mapping read/file, got %q", m["read_file"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRestoreSanitizedToolName(t *testing.T) {
	m := map[string]string{
		"mcp_server_read": "mcp/server/read",
		"tool_v2":         "tool@v2",
	}

	if got := RestoreSanitizedToolName(m, "mcp_server_read"); got != "mcp/server/read" {
		t.Errorf("expected mcp/server/read, got %q", got)
	}
	if got := RestoreSanitizedToolName(m, "unknown"); got != "unknown" {
		t.Errorf("expected passthrough for unknown, got %q", got)
	}
	if got := RestoreSanitizedToolName(nil, "name"); got != "name" {
		t.Errorf("expected passthrough for nil map, got %q", got)
	}
	if got := RestoreSanitizedToolName(m, ""); got != "" {
		t.Errorf("expected empty for empty name, got %q", got)
	}
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -243,6 +244,9 @@ func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
|
||||
out := make(map[string]string, len(toolResults))
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
name := strings.TrimSpace(tool.Get("name").String())
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(tool.Get("function.name").String())
|
||||
}
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
@@ -271,3 +275,54 @@ func MapToolName(toolNameMap map[string]string, name string) string {
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// SanitizedToolNameMap builds a sanitized-name → original-name map from Claude request tools.
// It is used to restore exact tool names for clients (e.g. Claude Code) after the proxy
// sanitizes tool names for Gemini/Vertex API compatibility via SanitizeFunctionName.
// Only entries where sanitization actually changes the name are included.
func SanitizedToolNameMap(rawJSON []byte) map[string]string {
	if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
		return nil
	}

	tools := gjson.GetBytes(rawJSON, "tools")
	if !tools.Exists() || !tools.IsArray() {
		return nil
	}

	out := make(map[string]string)
	tools.ForEach(func(_, tool gjson.Result) bool {
		name := strings.TrimSpace(tool.Get("name").String())
		if name == "" {
			return true
		}
		sanitized := SanitizeFunctionName(name)
		if sanitized == name {
			return true
		}
		if _, exists := out[sanitized]; !exists {
			out[sanitized] = name
		} else {
			log.Warnf("sanitized tool name collision: %q and %q both map to %q, keeping first", out[sanitized], name, sanitized)
		}
		return true
	})

	if len(out) == 0 {
		return nil
	}
	return out
}
|
||||
|
||||
// RestoreSanitizedToolName looks up a sanitized function name in the provided map
// and returns the original client-facing name. If no mapping exists, it returns
// the sanitized name unchanged.
func RestoreSanitizedToolName(toolNameMap map[string]string, sanitizedName string) string {
	if sanitizedName == "" || toolNameMap == nil {
		return sanitizedName
	}
	if original, ok := toolNameMap[sanitizedName]; ok {
		return original
	}
	return sanitizedName
}
|
||||
|
||||
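Review note: the round trip introduced by this change (sanitize tool names on the way upstream, restore them on the way back) can be sketched without the project's dependencies. Everything below is an assumption made for illustration: `sanitizeName` is a hypothetical stand-in for `SanitizeFunctionName`, whose real rules are not part of this diff, and `buildSanitizedMap` and `restoreName` only approximate `SanitizedToolNameMap` and `RestoreSanitizedToolName` using `encoding/json` instead of gjson.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// sanitizeName is a HYPOTHETICAL sanitizer: every rune outside
// [A-Za-z0-9_.-] becomes '_'. The real SanitizeFunctionName may differ.
func sanitizeName(name string) string {
	var b strings.Builder
	for _, r := range name {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9',
			r == '_', r == '.', r == '-':
			b.WriteRune(r)
		default:
			b.WriteRune('_')
		}
	}
	return b.String()
}

// buildSanitizedMap returns sanitized-name -> original-name for tools whose
// name changes under sanitizeName. On a collision the first mapping wins.
func buildSanitizedMap(raw []byte) map[string]string {
	var req struct {
		Tools []struct {
			Name string `json:"name"`
		} `json:"tools"`
	}
	if err := json.Unmarshal(raw, &req); err != nil {
		return nil
	}
	out := make(map[string]string)
	for _, tool := range req.Tools {
		name := strings.TrimSpace(tool.Name)
		sanitized := sanitizeName(name)
		if name == "" || sanitized == name {
			continue
		}
		if _, exists := out[sanitized]; !exists {
			out[sanitized] = name
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

// restoreName maps an upstream (sanitized) name back to the client's name,
// passing unknown names through unchanged. Reading a nil map is safe in Go.
func restoreName(m map[string]string, sanitized string) string {
	if original, ok := m[sanitized]; ok {
		return original
	}
	return sanitized
}

func main() {
	raw := []byte(`{"tools":[{"name":"valid_tool"},{"name":"mcp/server/read"},{"name":"tool@v2"}]}`)
	m := buildSanitizedMap(raw)
	fmt.Println(restoreName(m, "mcp_server_read"))
	fmt.Println(restoreName(m, "valid_tool"))
}
```

One consequence worth keeping in mind when reviewing: because the first mapping wins on a collision, two client tools that sanitize to the same upstream name cannot both be restored, which is why the diff logs a warning in that case.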
@@ -75,7 +75,12 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
cacheAuthContents := log.IsLevelEnabled(log.DebugLevel)
|
||||
if cacheAuthContents {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
} else {
|
||||
w.lastAuthContents = nil
|
||||
}
|
||||
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
@@ -89,10 +94,12 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
// Parse and cache auth content for future diff comparisons
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
// Parse and cache auth content for future diff comparisons (debug only).
|
||||
if cacheAuthContents {
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
}
|
||||
ctx := &synthesizer.SynthesisContext{
|
||||
Config: cfg,
|
||||
@@ -102,7 +109,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
}
|
||||
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
|
||||
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
|
||||
w.fileAuthsByPath[normalizedPath] = pathAuths
|
||||
w.fileAuthsByPath[normalizedPath] = authIDSet(pathAuths)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -171,25 +178,30 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
}
|
||||
|
||||
// Get old auth for diff comparison
|
||||
cacheAuthContents := log.IsLevelEnabled(log.DebugLevel)
|
||||
var oldAuth *coreauth.Auth
|
||||
if w.lastAuthContents != nil {
|
||||
if cacheAuthContents && w.lastAuthContents != nil {
|
||||
oldAuth = w.lastAuthContents[normalized]
|
||||
}
|
||||
|
||||
// Compute and log field changes
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
if cacheAuthContents {
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update caches
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
if cacheAuthContents {
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||
for id, a := range w.fileAuthsByPath[normalized] {
|
||||
@@ -206,7 +218,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
|
||||
newByID := authSliceToMap(generated)
|
||||
if len(newByID) > 0 {
|
||||
w.fileAuthsByPath[normalized] = newByID
|
||||
w.fileAuthsByPath[normalized] = authIDSet(newByID)
|
||||
} else {
|
||||
delete(w.fileAuthsByPath, normalized)
|
||||
}
|
||||
@@ -273,6 +285,14 @@ func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
|
||||
return byID
|
||||
}
|
||||
|
||||
func authIDSet(auths map[string]*coreauth.Auth) map[string]*coreauth.Auth {
	set := make(map[string]*coreauth.Auth, len(auths))
	for id := range auths {
		set[id] = nil
	}
	return set
}
|
||||
|
||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
authFileCount := 0
|
||||
successfulAuthCount := 0
|
||||
|
||||
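Review note: `authIDSet` keeps the map's keys but replaces every value with `nil`. A plausible reading, which is an assumption and not stated in the diff, is that the watcher only needs to know which auth IDs a file produced and does not want to hold the parsed auth objects in memory. The sketch below shows the pattern with a hypothetical `auth` type standing in for `coreauth.Auth`.

```go
package main

import "fmt"

// auth is a hypothetical stand-in for coreauth.Auth.
type auth struct {
	ID    string
	Token string
}

// idSet copies only the keys of auths. The values are nil, so the returned
// map no longer references the original auth objects.
func idSet(auths map[string]*auth) map[string]*auth {
	set := make(map[string]*auth, len(auths))
	for id := range auths {
		set[id] = nil
	}
	return set
}

func main() {
	full := map[string]*auth{
		"a": {ID: "a", Token: "secret-a"},
		"b": {ID: "b", Token: "secret-b"},
	}
	ids := idSet(full)
	_, known := ids["a"]
	// Membership is preserved, but the value is a nil pointer.
	fmt.Println(len(ids), known, ids["a"] == nil)
}
```

Any caller that ranges over such a map must treat the values as possibly nil and rely on the keys alone.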
@@ -340,12 +340,13 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
}
|
||||
}
|
||||
newCtx, cancel := context.WithCancel(parentCtx)
|
||||
+cancelCtx := newCtx
|
||||
if requestCtx != nil && requestCtx != parentCtx {
|
||||
go func() {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
cancel()
|
||||
-case <-newCtx.Done():
+case <-cancelCtx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -421,10 +421,6 @@ func preserveRequestedModelSuffix(requestedModel, resolved string) string {
|
||||
}
|
||||
|
||||
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
|
||||
return m.prepareExecutionModels(auth, routeModel)
|
||||
}
|
||||
|
||||
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
|
||||
requestedModel := rewriteModelForAuth(routeModel, auth)
|
||||
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
|
||||
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
|
||||
@@ -441,6 +437,46 @@ func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string
|
||||
return []string{resolved}
|
||||
}
|
||||
|
||||
func executionResultModel(routeModel, upstreamModel string, pooled bool) string {
|
||||
if pooled {
|
||||
if resolved := strings.TrimSpace(upstreamModel); resolved != "" {
|
||||
return resolved
|
||||
}
|
||||
}
|
||||
if requested := strings.TrimSpace(routeModel); requested != "" {
|
||||
return requested
|
||||
}
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
}
|
||||
|
||||
func filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
out := make([]string, 0, len(candidates))
|
||||
for _, upstreamModel := range candidates {
|
||||
stateModel := executionResultModel(routeModel, upstreamModel, pooled)
|
||||
blocked, _, _ := isAuthBlockedForModel(auth, stateModel, now)
|
||||
if blocked {
|
||||
continue
|
||||
}
|
||||
out = append(out, upstreamModel)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) preparedExecutionModels(auth *Auth, routeModel string) ([]string, bool) {
|
||||
candidates := m.executionModelCandidates(auth, routeModel)
|
||||
pooled := len(candidates) > 1
|
||||
return filterExecutionModels(auth, routeModel, candidates, pooled), pooled
|
||||
}
|
||||
|
||||
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
|
||||
models, _ := m.preparedExecutionModels(auth, routeModel)
|
||||
return models
|
||||
}
|
||||
|
||||
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
|
||||
if ch == nil {
|
||||
return
|
||||
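Review note: the hunk above changes which model name cooldown state is recorded against. A minimal, self-contained sketch of that selection logic follows. `executionResultModel` is copied from the hunk; `filterModels` and its `blocked` map are hypothetical stand-ins for `filterExecutionModels` and `isAuthBlockedForModel`, which depend on types not shown here.

```go
package main

import (
	"fmt"
	"strings"
)

// executionResultModel mirrors the helper added in the diff: pooled auths
// track state per upstream model, everything else per requested route model.
func executionResultModel(routeModel, upstreamModel string, pooled bool) string {
	if pooled {
		if resolved := strings.TrimSpace(upstreamModel); resolved != "" {
			return resolved
		}
	}
	if requested := strings.TrimSpace(routeModel); requested != "" {
		return requested
	}
	return strings.TrimSpace(upstreamModel)
}

// filterModels is a hypothetical simplification of filterExecutionModels.
// The blocked map (an assumption) replaces the real cooldown lookup.
func filterModels(routeModel string, candidates []string, pooled bool, blocked map[string]bool) []string {
	if len(candidates) == 0 {
		return nil
	}
	out := make([]string, 0, len(candidates))
	for _, upstream := range candidates {
		if blocked[executionResultModel(routeModel, upstream, pooled)] {
			continue // this candidate is cooling down; skip it
		}
		out = append(out, upstream)
	}
	return out
}

func main() {
	pool := []string{"vendor-a", "vendor-b"}
	// Pooled: only the blocked upstream model is dropped.
	fmt.Println(filterModels("alias", pool, true, map[string]bool{"vendor-a": true}))
	// Not pooled: state is keyed by the route model, so one block hides the single candidate.
	fmt.Println(filterModels("alias", pool[:1], false, map[string]bool{"alias": true}))
}
```

With a pooled auth, one failing upstream model no longer suspends the whole alias, which appears to be the intent of passing `pooled` through to `MarkResult` callers.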
@@ -451,6 +487,59 @@ func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
|
||||
}()
|
||||
}
|
||||
|
||||
type streamBootstrapError struct {
|
||||
cause error
|
||||
headers http.Header
|
||||
}
|
||||
|
||||
func cloneHTTPHeader(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func newStreamBootstrapError(err error, headers http.Header) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &streamBootstrapError{
|
||||
cause: err,
|
||||
headers: cloneHTTPHeader(headers),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *streamBootstrapError) Error() string {
|
||||
if e == nil || e.cause == nil {
|
||||
return ""
|
||||
}
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (e *streamBootstrapError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func (e *streamBootstrapError) Headers() http.Header {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return cloneHTTPHeader(e.headers)
|
||||
}
|
||||
|
||||
func streamErrorResult(headers http.Header, err error) *cliproxyexecutor.StreamResult {
|
||||
ch := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
ch <- cliproxyexecutor.StreamChunk{Err: err}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: cloneHTTPHeader(headers),
|
||||
Chunks: ch,
|
||||
}
|
||||
}
|
||||
|
||||
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
|
||||
if ch == nil {
|
||||
return nil, true, nil
|
||||
@@ -483,7 +572,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC
|
||||
}
|
||||
}
|
||||
|
||||
-func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
+func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, resultModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
@@ -496,7 +585,7 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
-m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
+m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
return false
|
||||
@@ -526,19 +615,19 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
-m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
+m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: true})
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
|
||||
}
|
||||
|
||||
-func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
+func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, execModels []string, pooled bool) (*cliproxyexecutor.StreamResult, error) {
|
||||
if executor == nil {
|
||||
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
-execModels := m.prepareExecutionModels(auth, routeModel)
var lastErr error
for idx, execModel := range execModels {
+resultModel := executionResultModel(routeModel, execModel, pooled)
|
||||
execReq := req
|
||||
execReq.Model = execModel
|
||||
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
|
||||
@@ -550,7 +639,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
-result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(ctx, result)
|
||||
if isRequestInvalidError(errStream) {
|
||||
@@ -571,7 +660,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
-result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
@@ -582,31 +671,33 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
-result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
lastErr = bootstrapErr
|
||||
continue
|
||||
}
|
||||
-errCh := make(chan cliproxyexecutor.StreamChunk, 1)
-errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
-close(errCh)
-return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
+rerr := &Error{Message: bootstrapErr.Error()}
+if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
+rerr.HTTPStatus = se.StatusCode()
+}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
+result.RetryAfter = retryAfterFromError(bootstrapErr)
+m.MarkResult(ctx, result)
+discardStreamChunks(streamResult.Chunks)
+return nil, newStreamBootstrapError(bootstrapErr, streamResult.Headers)
|
||||
}
|
||||
|
||||
if closed && len(buffered) == 0 {
|
||||
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
|
||||
-result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: emptyErr}
|
||||
m.MarkResult(ctx, result)
|
||||
if idx < len(execModels)-1 {
|
||||
lastErr = emptyErr
|
||||
continue
|
||||
}
|
||||
-errCh := make(chan cliproxyexecutor.StreamChunk, 1)
-errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
-close(errCh)
-return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
+return nil, newStreamBootstrapError(emptyErr, streamResult.Headers)
|
||||
}
|
||||
|
||||
remaining := streamResult.Chunks
|
||||
@@ -615,7 +706,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
|
||||
close(closedCh)
|
||||
remaining = closedCh
|
||||
}
|
||||
-return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
+return m.wrapStreamResult(ctx, auth.Clone(), provider, resultModel, streamResult.Headers, buffered, remaining), nil
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
|
||||
@@ -979,9 +1070,10 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
+attempted := make(map[string]struct{})
var lastErr error
for {
-if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
+if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
@@ -1006,13 +1098,18 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
-models := m.prepareExecutionModels(auth, routeModel)
+models, pooled := m.preparedExecutionModels(auth, routeModel)
+if len(models) == 0 {
+continue
+}
+attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
+resultModel := executionResultModel(routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
-result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
@@ -1051,9 +1148,10 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
+attempted := make(map[string]struct{})
var lastErr error
for {
-if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
+if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
@@ -1078,13 +1176,18 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
-models := m.prepareExecutionModels(auth, routeModel)
+models, pooled := m.preparedExecutionModels(auth, routeModel)
+if len(models) == 0 {
+continue
+}
+attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
+resultModel := executionResultModel(routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
-result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
+result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
@@ -1096,14 +1199,14 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
-m.hook.OnResult(execCtx, result)
+m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
}
|
||||
authErr = errExec
|
||||
continue
|
||||
}
|
||||
-m.hook.OnResult(execCtx, result)
+m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
if authErr != nil {
|
||||
@@ -1123,10 +1226,15 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
+attempted := make(map[string]struct{})
var lastErr error
for {
-if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
+if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
if lastErr != nil {
+var bootstrapErr *streamBootstrapError
+if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil {
+return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil
+}
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
@@ -1134,6 +1242,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
+var bootstrapErr *streamBootstrapError
+if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil {
+return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil
+}
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
@@ -1149,7 +1261,12 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
-streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
+models, pooled := m.preparedExecutionModels(auth, routeModel)
+if len(models) == 0 {
+continue
+}
+attempted[auth.ID] = struct{}{}
+streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
@@ -1627,53 +1744,60 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
}
|
||||
|
||||
statusCode := statusCodeFromResult(result.Error)
|
||||
switch statusCode {
|
||||
case 401:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "unauthorized"
|
||||
shouldSuspendModel = true
|
||||
case 402, 403:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "payment_required"
|
||||
shouldSuspendModel = true
|
||||
case 404:
|
||||
if isModelSupportResultError(result.Error) {
|
||||
next := now.Add(12 * time.Hour)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "not_found"
|
||||
suspendReason = "model_not_supported"
|
||||
shouldSuspendModel = true
|
||||
case 429:
|
||||
var next time.Time
|
||||
backoffLevel := state.Quota.BackoffLevel
|
||||
if result.RetryAfter != nil {
|
||||
next = now.Add(*result.RetryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
}
|
||||
backoffLevel = nextLevel
|
||||
}
|
||||
state.NextRetryAfter = next
|
||||
state.Quota = QuotaState{
|
||||
Exceeded: true,
|
||||
Reason: "quota",
|
||||
NextRecoverAt: next,
|
||||
BackoffLevel: backoffLevel,
|
||||
}
|
||||
suspendReason = "quota"
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
case 408, 500, 502, 503, 504:
|
||||
if quotaCooldownDisabledForAuth(auth) {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
} else {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "unauthorized"
|
||||
shouldSuspendModel = true
|
||||
case 402, 403:
|
||||
next := now.Add(30 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "payment_required"
|
||||
shouldSuspendModel = true
|
||||
case 404:
|
||||
next := now.Add(12 * time.Hour)
|
||||
state.NextRetryAfter = next
|
||||
suspendReason = "not_found"
|
||||
shouldSuspendModel = true
|
||||
case 429:
|
||||
var next time.Time
|
||||
backoffLevel := state.Quota.BackoffLevel
|
||||
if result.RetryAfter != nil {
|
||||
next = now.Add(*result.RetryAfter)
|
||||
} else {
|
||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||
if cooldown > 0 {
|
||||
next = now.Add(cooldown)
|
||||
}
|
||||
backoffLevel = nextLevel
|
||||
}
|
||||
state.NextRetryAfter = next
|
||||
state.Quota = QuotaState{
|
||||
Exceeded: true,
|
||||
Reason: "quota",
|
||||
NextRecoverAt: next,
|
||||
BackoffLevel: backoffLevel,
|
||||
}
|
||||
suspendReason = "quota"
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
case 408, 500, 502, 503, 504:
|
||||
if quotaCooldownDisabledForAuth(auth) {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
}
|
||||
default:
|
||||
state.NextRetryAfter = time.Time{}
|
||||
}
|
||||
default:
|
||||
state.NextRetryAfter = time.Time{}
|
||||
}
|
||||
|
||||
auth.Status = StatusError
|
||||
@@ -1883,14 +2007,65 @@ func statusCodeFromResult(err *Error) int {
|
||||
return err.StatusCode()
|
||||
}
|
||||
|
||||
func isModelSupportErrorMessage(message string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(message))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
patterns := [...]string{
|
||||
"model_not_supported",
|
||||
"requested model is not supported",
|
||||
"requested model is unsupported",
|
||||
"requested model is unavailable",
|
||||
"model is not supported",
|
||||
"model not supported",
|
||||
"unsupported model",
|
||||
"model unavailable",
|
||||
"not available for your plan",
|
||||
"not available for your account",
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(lower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isModelSupportError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromError(err)
|
||||
if status != http.StatusBadRequest && status != http.StatusUnprocessableEntity {
|
||||
return false
|
||||
}
|
||||
return isModelSupportErrorMessage(err.Error())
|
||||
}
|
||||
|
||||
func isModelSupportResultError(err *Error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromResult(err)
|
||||
if status != http.StatusBadRequest && status != http.StatusUnprocessableEntity {
|
||||
return false
|
||||
}
|
||||
return isModelSupportErrorMessage(err.Message)
|
||||
}
|
||||
|
||||
// isRequestInvalidError returns true if the error represents a client request
|
||||
// error that should not be retried. Specifically, it treats 400 responses with
|
||||
// "invalid_request_error" and all 422 responses as request-shape failures,
|
||||
-// where switching auths or pooled upstream models will not help.
+// where switching auths or pooled upstream models will not help. Model-support
+// errors are excluded so routing can fall through to another auth or upstream.
|
||||
func isRequestInvalidError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
+if isModelSupportError(err) {
+return false
+}
|
||||
status := statusCodeFromError(err)
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
|
||||
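Review note: the effect of the new `isModelSupportError` guard is easiest to see as a three-way decision. The sketch below is an illustration only. `classify` and its return labels are hypothetical, the pattern list is a subset of the one in the diff, and the 400/422 handling follows the doc comment on `isRequestInvalidError` rather than the function body, which this hunk does not show in full.

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// isModelSupportMessage checks a subset of the patterns listed in the diff.
func isModelSupportMessage(message string) bool {
	lower := strings.ToLower(strings.TrimSpace(message))
	if lower == "" {
		return false
	}
	patterns := []string{
		"model_not_supported",
		"requested model is not supported",
		"unsupported model",
		"not available for your plan",
	}
	for _, p := range patterns {
		if strings.Contains(lower, p) {
			return true
		}
	}
	return false
}

// classify is a hypothetical helper: "fallback" means try another auth or
// upstream model, "abort" means the request itself is invalid, and "retry"
// covers everything else.
func classify(status int, message string) string {
	badShape := status == http.StatusBadRequest || status == http.StatusUnprocessableEntity
	if badShape && isModelSupportMessage(message) {
		return "fallback" // checked first, so it wins over invalid_request_error
	}
	switch status {
	case http.StatusBadRequest:
		if strings.Contains(message, "invalid_request_error") {
			return "abort"
		}
	case http.StatusUnprocessableEntity:
		return "abort"
	}
	return "retry"
}

func main() {
	fmt.Println(classify(400, "invalid_request_error: The requested model is not supported."))
	fmt.Println(classify(400, "invalid_request_error: messages must not be empty"))
	fmt.Println(classify(422, "schema mismatch"))
}
```

The ordering matters: the first message contains `invalid_request_error` too, so without the early model-support check it would abort instead of falling through to the next auth, which is the case the new tests exercise.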
@@ -108,6 +108,76 @@ func (e *credentialRetryLimitExecutor) Calls() int {
|
||||
return e.calls
|
||||
}
|
||||
|
||||
type authFallbackExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
executeCalls []string
|
||||
streamCalls []string
|
||||
executeErrors map[string]error
|
||||
streamFirstErrors map[string]error
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) Identifier() string {
|
||||
return e.id
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) Execute(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
e.mu.Lock()
|
||||
e.executeCalls = append(e.executeCalls, auth.ID)
|
||||
err := e.executeErrors[auth.ID]
|
||||
e.mu.Unlock()
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
return cliproxyexecutor.Response{Payload: []byte(auth.ID)}, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) ExecuteStream(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
e.streamCalls = append(e.streamCalls, auth.ID)
|
||||
err := e.streamFirstErrors[auth.ID]
|
||||
e.mu.Unlock()
|
||||
|
||||
ch := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
if err != nil {
|
||||
ch <- cliproxyexecutor.StreamChunk{Err: err}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil
|
||||
}
|
||||
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(auth.ID)}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "not implemented"}
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) ExecuteCalls() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.executeCalls))
|
||||
copy(out, e.executeCalls)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *authFallbackExecutor) StreamCalls() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.streamCalls))
|
||||
copy(out, e.streamCalls)
|
||||
return out
|
||||
}
|
||||
|
||||
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
|
||||
t.Helper()
|
||||
|
||||
@@ -191,6 +261,153 @@ func TestManager_MaxRetryCredentials_LimitsCrossCredentialRetries(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_ModelSupportBadRequest_FallsBackAndSuspendsAuth(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
executor := &authFallbackExecutor{
|
||||
id: "claude",
|
||||
executeErrors: map[string]error{
|
||||
"aa-bad-auth": &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
},
|
||||
},
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
model := "claude-opus-4-6"
|
||||
badAuth := &Auth{ID: "aa-bad-auth", Provider: "claude"}
|
||||
goodAuth := &Auth{ID: "bb-good-auth", Provider: "claude"}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(badAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
reg.RegisterClient(goodAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(badAuth.ID)
|
||||
reg.UnregisterClient(goodAuth.ID)
|
||||
})
|
||||
|
||||
if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
|
||||
t.Fatalf("register bad auth: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
|
||||
t.Fatalf("register good auth: %v", errRegister)
|
||||
}
|
||||
|
||||
request := cliproxyexecutor.Request{Model: model}
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, errExecute := m.Execute(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
if errExecute != nil {
|
||||
t.Fatalf("execute %d error = %v, want success", i, errExecute)
|
||||
}
|
||||
if string(resp.Payload) != goodAuth.ID {
|
||||
t.Fatalf("execute %d payload = %q, want %q", i, string(resp.Payload), goodAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.ExecuteCalls()
|
||||
want := []string{badAuth.ID, goodAuth.ID, goodAuth.ID}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d auth = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
|
||||
updatedBad, ok := m.GetByID(badAuth.ID)
|
||||
if !ok || updatedBad == nil {
|
||||
t.Fatalf("expected bad auth to remain registered")
|
||||
}
|
||||
state := updatedBad.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state for %q", model)
|
||||
}
|
||||
if !state.Unavailable {
|
||||
t.Fatalf("expected bad auth model state to be unavailable")
|
||||
}
|
||||
if state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected bad auth model state cooldown to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_ModelSupportBadRequestFallsBackAndSuspendsAuth(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
executor := &authFallbackExecutor{
|
||||
id: "claude",
|
||||
streamFirstErrors: map[string]error{
|
||||
"aa-bad-auth": &Error{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "invalid_request_error: The requested model is not supported.",
|
||||
},
|
||||
},
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
model := "claude-opus-4-6"
|
||||
badAuth := &Auth{ID: "aa-bad-auth", Provider: "claude"}
|
||||
goodAuth := &Auth{ID: "bb-good-auth", Provider: "claude"}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(badAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
reg.RegisterClient(goodAuth.ID, "claude", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(badAuth.ID)
|
||||
reg.UnregisterClient(goodAuth.ID)
|
||||
})
|
||||
|
||||
if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
|
||||
t.Fatalf("register bad auth: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
|
||||
t.Fatalf("register good auth: %v", errRegister)
|
||||
}
|
||||
|
||||
request := cliproxyexecutor.Request{Model: model}
|
||||
for i := 0; i < 2; i++ {
|
||||
streamResult, errExecute := m.ExecuteStream(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
if errExecute != nil {
|
||||
t.Fatalf("execute stream %d error = %v, want success", i, errExecute)
|
||||
}
|
||||
var payload []byte
|
||||
for chunk := range streamResult.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("execute stream %d chunk error = %v, want success", i, chunk.Err)
|
||||
}
|
||||
payload = append(payload, chunk.Payload...)
|
||||
}
|
||||
if string(payload) != goodAuth.ID {
|
||||
t.Fatalf("execute stream %d payload = %q, want %q", i, string(payload), goodAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.StreamCalls()
|
||||
want := []string{badAuth.ID, goodAuth.ID, goodAuth.ID}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("stream calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d auth = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
|
||||
updatedBad, ok := m.GetByID(badAuth.ID)
|
||||
if !ok || updatedBad == nil {
|
||||
t.Fatalf("expected bad auth to remain registered")
|
||||
}
|
||||
state := updatedBad.ModelStates[model]
|
||||
if state == nil {
|
||||
t.Fatalf("expected model state for %q", model)
|
||||
}
|
||||
if !state.Unavailable {
|
||||
t.Fatalf("expected bad auth model state to be unavailable")
|
||||
}
|
||||
if state.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("expected bad auth model state cooldown to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
|
||||
prev := quotaCooldownDisabled.Load()
|
||||
quotaCooldownDisabled.Store(false)
|
||||
|
||||
@@ -3,6 +3,7 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -116,6 +117,47 @@ func (e *openAICompatPoolExecutor) StreamModels() []string {
	return out
}

type authScopedOpenAICompatPoolExecutor struct {
	id string

	mu           sync.Mutex
	executeCalls []string
}

func (e *authScopedOpenAICompatPoolExecutor) Identifier() string { return e.id }

func (e *authScopedOpenAICompatPoolExecutor) Execute(_ context.Context, auth *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	call := auth.ID + "|" + req.Model
	e.mu.Lock()
	e.executeCalls = append(e.executeCalls, call)
	e.mu.Unlock()
	return cliproxyexecutor.Response{Payload: []byte(call)}, nil
}

func (e *authScopedOpenAICompatPoolExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "ExecuteStream not implemented"}
}

func (e *authScopedOpenAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (e *authScopedOpenAICompatPoolExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "CountTokens not implemented"}
}

func (e *authScopedOpenAICompatPoolExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}

func (e *authScopedOpenAICompatPoolExecutor) ExecuteCalls() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	out := make([]string, len(e.executeCalls))
	copy(out, e.executeCalls)
	return out
}

func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
	t.Helper()
	cfg := &internalconfig.Config{
@@ -153,6 +195,21 @@ func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []interna
	return m
}

func readOpenAICompatStreamPayload(t *testing.T, streamResult *cliproxyexecutor.StreamResult) string {
	t.Helper()
	if streamResult == nil {
		t.Fatal("expected stream result")
	}
	var payload []byte
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("unexpected stream error: %v", chunk.Err)
		}
		payload = append(payload, chunk.Payload...)
	}
	return string(payload)
}

func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
	alias := "claude-opus-4.66"
	invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
@@ -243,6 +300,87 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
		t.Fatalf("execute calls = %v, want only first invalid model", got)
	}
}

func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t *testing.T) {
	alias := "claude-opus-4.66"
	modelSupportErr := &Error{
		HTTPStatus: http.StatusBadRequest,
		Message:    "invalid_request_error: The requested model is not supported.",
	}
	executor := &openAICompatPoolExecutor{
		id:            "pool",
		executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)

	resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute error = %v, want fallback success", err)
	}
	if string(resp.Payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
	}
	got := executor.ExecuteModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	if len(got) != len(want) {
		t.Fatalf("execute calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
		}
	}

	updated, ok := m.GetByID("pool-auth-" + t.Name())
	if !ok || updated == nil {
		t.Fatalf("expected auth to remain registered")
	}
	state := updated.ModelStates["qwen3.5-plus"]
	if state == nil {
		t.Fatalf("expected suspended upstream model state")
	}
	if !state.Unavailable || state.NextRetryAfter.IsZero() {
		t.Fatalf("expected upstream model suspension, got %+v", state)
	}
}

func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessableEntity(t *testing.T) {
	alias := "claude-opus-4.66"
	modelSupportErr := &Error{
		HTTPStatus: http.StatusUnprocessableEntity,
		Message:    "The requested model is not supported.",
	}
	executor := &openAICompatPoolExecutor{
		id:            "pool",
		executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)

	resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute error = %v, want fallback success", err)
	}
	if string(resp.Payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
	}
	got := executor.ExecuteModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	if len(got) != len(want) {
		t.Fatalf("execute calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}

func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
@@ -364,6 +502,84 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test
		t.Fatalf("stream calls = %v, want only first invalid model", got)
	}
}

func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) {
	alias := "claude-opus-4.66"
	modelSupportErr := &Error{
		HTTPStatus: http.StatusBadRequest,
		Message:    "invalid_request_error: The requested model is not supported.",
	}
	executor := &openAICompatPoolExecutor{
		id:            "pool",
		executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)

	for i := 0; i < 3; i++ {
		resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute %d: %v", i, err)
		}
		if string(resp.Payload) != "glm-5" {
			t.Fatalf("execute %d payload = %q, want %q", i, string(resp.Payload), "glm-5")
		}
	}

	got := executor.ExecuteModels()
	want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
	if len(got) != len(want) {
		t.Fatalf("execute calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}

func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) {
	alias := "claude-opus-4.66"
	modelSupportErr := &Error{
		HTTPStatus: http.StatusUnprocessableEntity,
		Message:    "The requested model is not supported.",
	}
	executor := &openAICompatPoolExecutor{
		id:                "pool",
		streamFirstErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)

	for i := 0; i < 3; i++ {
		streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute stream %d: %v", i, err)
		}
		if payload := readOpenAICompatStreamPayload(t, streamResult); payload != "glm-5" {
			t.Fatalf("execute stream %d payload = %q, want %q", i, payload, "glm-5")
		}
		if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
			t.Fatalf("execute stream %d header X-Model = %q, want %q", i, gotHeader, "glm-5")
		}
	}

	got := executor.StreamModels()
	want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
	if len(got) != len(want) {
		t.Fatalf("stream calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}

func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{id: "pool"}
@@ -391,6 +607,127 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T
	}
}

func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) {
	alias := "claude-opus-4.66"
	modelSupportErr := &Error{
		HTTPStatus: http.StatusBadRequest,
		Message:    "invalid_request_error: The requested model is unsupported.",
	}
	executor := &openAICompatPoolExecutor{
		id:          "pool",
		countErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)

	for i := 0; i < 3; i++ {
		resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute count %d: %v", i, err)
		}
		if string(resp.Payload) != "glm-5" {
			t.Fatalf("execute count %d payload = %q, want %q", i, string(resp.Payload), "glm-5")
		}
	}

	got := executor.CountModels()
	want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
	if len(got) != len(want) {
		t.Fatalf("count calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}

func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudget(t *testing.T) {
	alias := "claude-opus-4.66"
	cfg := &internalconfig.Config{
		OpenAICompatibility: []internalconfig.OpenAICompatibility{{
			Name: "pool",
			Models: []internalconfig.OpenAICompatibilityModel{
				{Name: "qwen3.5-plus", Alias: alias},
				{Name: "glm-5", Alias: alias},
			},
		}},
	}
	m := NewManager(nil, nil, nil)
	m.SetConfig(cfg)
	m.SetRetryConfig(0, 0, 1)

	executor := &authScopedOpenAICompatPoolExecutor{id: "pool"}
	m.RegisterExecutor(executor)

	badAuth := &Auth{
		ID:       "aa-blocked-auth",
		Provider: "pool",
		Status:   StatusActive,
		Attributes: map[string]string{
			"api_key":      "bad-key",
			"compat_name":  "pool",
			"provider_key": "pool",
		},
	}
	goodAuth := &Auth{
		ID:       "bb-good-auth",
		Provider: "pool",
		Status:   StatusActive,
		Attributes: map[string]string{
			"api_key":      "good-key",
			"compat_name":  "pool",
			"provider_key": "pool",
		},
	}
	if _, err := m.Register(context.Background(), badAuth); err != nil {
		t.Fatalf("register bad auth: %v", err)
	}
	if _, err := m.Register(context.Background(), goodAuth); err != nil {
		t.Fatalf("register good auth: %v", err)
	}

	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(badAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
	reg.RegisterClient(goodAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
	t.Cleanup(func() {
		reg.UnregisterClient(badAuth.ID)
		reg.UnregisterClient(goodAuth.ID)
	})

	modelSupportErr := &Error{
		HTTPStatus: http.StatusBadRequest,
		Message:    "invalid_request_error: The requested model is not supported.",
	}
	for _, upstreamModel := range []string{"qwen3.5-plus", "glm-5"} {
		m.MarkResult(context.Background(), Result{
			AuthID:   badAuth.ID,
			Provider: "pool",
			Model:    upstreamModel,
			Success:  false,
			Error:    modelSupportErr,
		})
	}

	resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute error = %v, want success via fallback auth", err)
	}
	if !strings.HasPrefix(string(resp.Payload), goodAuth.ID+"|") {
		t.Fatalf("payload = %q, want auth %q", string(resp.Payload), goodAuth.ID)
	}

	got := executor.ExecuteCalls()
	if len(got) != 1 {
		t.Fatalf("execute calls = %v, want only one real execution on fallback auth", got)
	}
	if !strings.HasPrefix(got[0], goodAuth.ID+"|") {
		t.Fatalf("execute call = %q, want fallback auth %q", got[0], goodAuth.ID)
	}
}

func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
	alias := "claude-opus-4.66"
	invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
@@ -17,6 +17,7 @@ type Record struct {
	AuthIndex   string
	Source      string
	RequestedAt time.Time
	Latency     time.Duration
	Failed      bool
	Detail      Detail
}
@@ -3,6 +3,10 @@ package translator
import (
	"context"
	"sync"

	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// Registry manages translation functions across schemas.
@@ -39,7 +43,9 @@ func (r *Registry) Register(from, to Format, request RequestTransform, response
}

// TranslateRequest converts a payload between schemas, returning the original payload
// if no translator is registered.
// if no translator is registered. When falling back to the original payload, the
// "model" field is still updated to match the resolved model name so that
// client-side prefixes (e.g. "copilot/gpt-5-mini") are not leaked upstream.
func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
	r.mu.RLock()
	defer r.mu.RUnlock()
@@ -49,6 +55,13 @@ func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byt
			return fn(model, rawJSON, stream)
		}
	}
	if model != "" && gjson.GetBytes(rawJSON, "model").String() != model {
		if updated, err := sjson.SetBytes(rawJSON, "model", model); err != nil {
			log.Warnf("translator: failed to normalize model in request fallback: %v", err)
		} else {
			return updated
		}
	}
	return rawJSON
}
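The fallback branch added to TranslateRequest above can be illustrated in isolation. The following is a minimal sketch, not the project's implementation: `normalizeModel` is a hypothetical helper name, and it uses only the standard library's `encoding/json` rather than the gjson/sjson calls in the diff. As a result the rewritten payload comes back with its keys in sorted order, whereas the real code edits the `model` field in place.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// normalizeModel is a hypothetical stand-in for the fallback path of
// TranslateRequest: when no translator is registered, the payload is passed
// through, but its "model" field is forced to the resolved model name.
func normalizeModel(rawJSON []byte, model string) []byte {
	// An empty resolved model means there is nothing to normalize.
	if model == "" {
		return rawJSON
	}
	var body map[string]json.RawMessage
	if err := json.Unmarshal(rawJSON, &body); err != nil {
		// Not a JSON object: leave the payload untouched.
		return rawJSON
	}
	want, err := json.Marshal(model)
	if err != nil {
		return rawJSON
	}
	// Already matching: return the original bytes unchanged.
	if cur, ok := body["model"]; ok && string(cur) == string(want) {
		return rawJSON
	}
	body["model"] = want
	out, err := json.Marshal(body)
	if err != nil {
		return rawJSON
	}
	return out
}

func main() {
	in := []byte(`{"model":"copilot/gpt-5-mini","input":"ping"}`)
	// The client-side prefix "copilot/" is dropped; "input" is preserved.
	fmt.Println(string(normalizeModel(in, "gpt-5-mini")))
}
```

The three early returns mirror the cases exercised by the table-driven test in registry_test.go: an empty model and an already matching model both leave the payload byte-for-byte unchanged, and only a mismatched model triggers a rewrite.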
92
sdk/translator/registry_test.go
Normal file
@@ -0,0 +1,92 @@
package translator

import (
	"testing"

	"github.com/tidwall/gjson"
)

func TestTranslateRequest_FallbackNormalizesModel(t *testing.T) {
	r := NewRegistry()

	tests := []struct {
		name          string
		model         string
		payload       string
		wantModel     string
		wantUnchanged bool
	}{
		{
			name:      "prefixed model is rewritten",
			model:     "gpt-5-mini",
			payload:   `{"model":"copilot/gpt-5-mini","input":"ping"}`,
			wantModel: "gpt-5-mini",
		},
		{
			name:          "matching model is left unchanged",
			model:         "gpt-5-mini",
			payload:       `{"model":"gpt-5-mini","input":"ping"}`,
			wantModel:     "gpt-5-mini",
			wantUnchanged: true,
		},
		{
			name:          "empty model leaves payload unchanged",
			model:         "",
			payload:       `{"model":"copilot/gpt-5-mini","input":"ping"}`,
			wantModel:     "copilot/gpt-5-mini",
			wantUnchanged: true,
		},
		{
			name:      "deeply prefixed model is rewritten",
			model:     "gpt-5.3-codex",
			payload:   `{"model":"team/gpt-5.3-codex","stream":true}`,
			wantModel: "gpt-5.3-codex",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			input := []byte(tt.payload)
			got := r.TranslateRequest(Format("a"), Format("b"), tt.model, input, false)

			gotModel := gjson.GetBytes(got, "model").String()
			if gotModel != tt.wantModel {
				t.Errorf("model = %q, want %q", gotModel, tt.wantModel)
			}

			if tt.wantUnchanged && string(got) != tt.payload {
				t.Errorf("payload was modified when it should not have been:\ngot: %s\nwant: %s", got, tt.payload)
			}

			// Verify other fields are preserved.
			for _, key := range []string{"input", "stream"} {
				orig := gjson.Get(tt.payload, key)
				if !orig.Exists() {
					continue
				}
				after := gjson.GetBytes(got, key)
				if orig.Raw != after.Raw {
					t.Errorf("field %q changed: got %s, want %s", key, after.Raw, orig.Raw)
				}
			}
		})
	}
}

func TestTranslateRequest_RegisteredTransformTakesPrecedence(t *testing.T) {
	r := NewRegistry()
	from := Format("openai-response")
	to := Format("openai-response")

	r.Register(from, to, func(model string, rawJSON []byte, stream bool) []byte {
		return []byte(`{"model":"from-transform"}`)
	}, ResponseTransform{})

	input := []byte(`{"model":"copilot/gpt-5-mini","input":"ping"}`)
	got := r.TranslateRequest(from, to, "gpt-5-mini", input, false)

	gotModel := gjson.GetBytes(got, "model").String()
	if gotModel != "from-transform" {
		t.Errorf("expected registered transform to take precedence, got model = %q", gotModel)
	}
}