Compare commits

...

55 Commits

Author SHA1 Message Date
Luis Pater
05a35662ae Merge branch 'router-for-me:main' into main 2026-03-09 23:05:51 +08:00
Luis Pater
ce53d3a287 Fixed: #1997
test(auth-scheduler): add benchmarks and priority-based scheduling improvements

- Added `BenchmarkManagerPickNextMixedPriority500` for mixed-priority performance assessment.
- Updated `pickNextMixed` to prioritize highest ready priority tiers.
- Introduced `highestReadyPriorityLocked` and `pickReadyAtPriorityLocked` for better scheduling logic.
- Added unit test to validate selection of highest priority tiers in mixed provider scenarios.
2026-03-09 22:27:15 +08:00
Luis Pater
4cc99e7449 Merge pull request #1992 from dcrdev/main
System prompt silently dropped when sent as a string
2026-03-09 21:03:15 +08:00
Luis Pater
71773fe032 Merge pull request #1996 from router-for-me/codex/fix-unbounded-websocket-log-buffering
fix: cap websocket body log growth in responses handler
2026-03-09 20:50:38 +08:00
Dominic Robinson
a1e0fa0f39 test(executor): cover string system prompt handling in checkSystemInstructionsWithMode 2026-03-09 12:40:27 +00:00
Supra4E8C
fc2f0b6983 fix: cap websocket body log growth 2026-03-09 17:48:30 +08:00
Dominic Robinson
5c9997cdac fix: Preserve system prompt when sent as a string instead of content block array 2026-03-09 07:38:11 +00:00
Luis Pater
6f81046730 docs: remove outdated sections from README and README_CN 2026-03-09 09:35:25 +08:00
Luis Pater
0687472d01 Merge pull request #422 from router-for-me/plus
v6.8.49
2026-03-09 09:34:05 +08:00
Luis Pater
7739738fb3 Merge branch 'main' into plus 2026-03-09 09:33:22 +08:00
Luis Pater
99d1ce247b Merge pull request #420 from Skadli/codex/responses-computer-tool
Fixed: preserve Responses computer tool passthrough
2026-03-09 09:31:30 +08:00
Luis Pater
f5941a411c test(auth): cover scheduler refresh regression paths 2026-03-09 09:27:56 +08:00
Luis Pater
ba672bbd07 Merge PR #1969 into dev 2026-03-09 09:25:06 +08:00
Luis Pater
d9c6627a53 Merge pull request #1963 from qixing-jk/docs/add-all-api-hub-showcase
docs: add All API Hub to related projects list
2026-03-09 09:16:41 +08:00
Luis Pater
2e9907c3ac Merge pull request #1959 from thebtf/fix/system-instruction-camelcase
fix: use camelCase systemInstruction in OpenAI-to-Gemini translators
2026-03-09 09:09:03 +08:00
DragonFSKY
90afb9cb73 fix(auth): new OAuth accounts invisible to scheduler after dynamic registration
When new OAuth auth files are added while the service is running,
`applyCoreAuthAddOrUpdate` calls `coreManager.Register()` (which upserts
into the scheduler) BEFORE `registerModelsForAuth()`. At upsert time,
`buildScheduledAuthMeta` snapshots `supportedModelSetForAuth` from the
global model registry — but models haven't been registered yet, so the
set is empty. With an empty `supportedModelSet`, `supportsModel()`
always returns false and the new auth is never added to any model shard.

Additionally, when all existing accounts are in cooldown, the scheduler
returns `modelCooldownError`, but `shouldRetrySchedulerPick` only
handles `*Error` types — so the `syncScheduler` safety-net rebuild
never triggers and the new accounts remain invisible.

Fix:
1. Add `RefreshSchedulerEntry()` to re-upsert a single auth after its
   models are registered, rebuilding `supportedModelSet` from the
   now-populated registry.
2. Call it from `applyCoreAuthAddOrUpdate` after `registerModelsForAuth`.
3. Make `shouldRetrySchedulerPick` also match `*modelCooldownError` so
   the full scheduler rebuild triggers when all credentials are cooling
   down — catching any similar stale-snapshot edge cases.
2026-03-09 03:11:47 +08:00
anime
d0cc0cd9a5 docs: add All API Hub to related projects list
- Update README.md with All API Hub entry in English
- Update README_CN.md with All API Hub entry in Chinese
2026-03-09 02:00:16 +08:00
Kirill Turanskiy
338321e553 fix: use camelCase systemInstruction in OpenAI-to-Gemini translators
The Gemini v1internal (cloudcode-pa) and Antigravity Manager endpoints
require camelCase "systemInstruction" in request JSON. The current
snake_case "system_instruction" causes system prompts to be silently
ignored when routing through these endpoints.

Replace all "system_instruction" JSON keys with "systemInstruction" in
chat-completions and responses request translators.
2026-03-08 15:59:13 +03:00
Luis Pater
182b31963a Merge branch 'router-for-me:main' into main 2026-03-08 20:48:05 +08:00
Luis Pater
4f48e5254a Merge pull request #1957 from router-for-me/thinking
fix(translator): pass through adaptive thinking effort
2026-03-08 20:46:58 +08:00
Luis Pater
15dd5db1d7 Merge pull request #1956 from router-for-me/vertex
fix(executor): use aiplatform base url for vertex api key calls
2026-03-08 20:46:28 +08:00
hkfires
424711b718 fix(executor): use aiplatform base url for vertex api key calls 2026-03-08 20:13:12 +08:00
skad
91a2b1f0b4 Fixed: preserve Responses computer tool passthrough
Keep the OpenAI Responses computer tool intact when normalizing requests for the GitHub Copilot executor.

This change preserves built-in computer tool definitions instead of dropping them as non-function tools, keeps explicit computer tool_choice selections unchanged, and classifies computer_call / computer_call_output items as assistant and tool turns when deriving the initiator header.

Together these adjustments allow Responses requests that use the computer tool to reach the upstream executor without losing tool metadata or switching turn ownership unexpectedly.
2026-03-08 13:59:32 +08:00
Luis Pater
2b134fc378 test(auth-scheduler): add unit tests and scheduler implementation
- Added comprehensive unit tests for `authScheduler` and related components.
- Implemented `authScheduler` with support for Round Robin, Fill First, and custom selector strategies.
- Improved tracking of auth states, cooldowns, and recovery logic in scheduler.
2026-03-08 05:52:55 +08:00
Luis Pater
b9153719b0 Merge pull request #1925 from shenshuoyaoyouguang/pr/openai-compat-pool-thinking
fix(openai-compat): improve pool fallback and preserve adaptive thinking
2026-03-08 01:05:05 +08:00
Luis Pater
631e5c8331 Merge pull request #1922 from shenshuoyaoyouguang/pr/model-registry-safety
fix(registry): clone model snapshots and invalidate available-model cache
2026-03-07 23:01:42 +08:00
Luis Pater
e9c60a0a67 Merge pull request #1910 from thebtf/fix/gemini-oauth-error-messages
fix: surface upstream error details in Gemini CLI OAuth onboarding UI
2026-03-07 22:25:18 +08:00
Luis Pater
98a1bb5a7f Merge pull request #1900 from rex-zsd/feature/add-gemini-3.1-flash-image-preview
feat(registry): add gemini-3.1-flash-image-preview model definition
2026-03-07 22:17:10 +08:00
Luis Pater
ca90487a8c Merge branch 'main' into feature/add-gemini-3.1-flash-image-preview 2026-03-07 22:16:09 +08:00
Luis Pater
1042489f85 Merge pull request #1893 from thebtf/fix/normalize-ttl-byte-preservation-mainline
fix: preserve original JSON bytes in normalizeCacheControlTTL
2026-03-07 22:14:13 +08:00
Luis Pater
38277c1ea6 Merge pull request #1875 from woqiqishi/fix/tool-use-id-sanitize
fix: sanitize tool_use.id to comply with Claude API regex ^[a-zA-Z0-9_-]+$
2026-03-07 22:06:36 +08:00
Luis Pater
ee0c24628f Merge branch 'router-for-me:main' into main 2026-03-07 20:42:22 +08:00
chujian
3a18f6fcca fix(registry): clone slice fields in model map output 2026-03-07 18:53:56 +08:00
chujian
099e734a02 fix(registry): always clone available model snapshots 2026-03-07 18:40:02 +08:00
chujian
a52da26b5d fix(auth): stop draining stream pool goroutines after context cancellation 2026-03-07 18:30:33 +08:00
chujian
522a68a4ea fix(openai-compat): retry empty bootstrap streams 2026-03-07 18:08:13 +08:00
chujian
a02eda54d0 fix(openai-compat): address review feedback 2026-03-07 17:39:42 +08:00
chujian
97ef633c57 fix(registry): address review feedback 2026-03-07 17:36:57 +08:00
chujian
dae8463ba1 fix(registry): clone model snapshots and invalidate available-model cache 2026-03-07 16:59:23 +08:00
chujian
7c1299922e fix(openai-compat): improve pool fallback and preserve adaptive thinking 2026-03-07 16:54:28 +08:00
Luis Pater
ddcf1f279d Fixed: #1901
test(websocket): add tests for incremental input and prewarm handling logic

- Added test cases for incremental input support based on upstream capabilities.
- Introduced validation for prewarm handling of `response.create` messages locally.
- Enhanced test coverage for websocket executor behavior, including payload forwarding checks.
- Updated websocket implementation with prewarm and incremental input logic for better testability.
2026-03-07 13:11:28 +08:00
Luis Pater
7e6bb8fdc5 Merge origin/dev into pr-1774-review and resolve watcher conflicts 2026-03-07 11:12:42 +08:00
Luis Pater
9cee8ef87b Merge pull request #1684 from alexey-yanchenko/fix/input-audio-from-openai-to-antigravity
fix: preserve input_audio content parts when proxying to Antigravity
2026-03-07 10:12:28 +08:00
Luis Pater
93fb841bcb Fixed: #1670
test(translator): add unit tests for OpenAI to Claude requests and tool result handling

- Introduced tests for converting OpenAI requests to Claude with text, base64 images, and URL images in tool results.
- Refactored `convertClaudeToolResultContent` and related functionality to properly handle raw content with images and text.
- Updated conversion logic to streamline image handling for both base64 and URL formats.
2026-03-07 09:25:22 +08:00
Kirill Turanskiy
11a795a01c fix: surface upstream error details in Gemini CLI OAuth onboarding UI
SetOAuthSessionError previously sent generic messages to the management
panel (e.g. "Failed to complete Gemini CLI onboarding"), hiding the
actual error returned by Google APIs. The specific error was only
written to the server log via log.Errorf, which is often inaccessible
in headless/Docker deployments.

Include the upstream error in all 8 OAuth error paths so the
management panel shows actionable messages like "no Google Cloud
projects available for this account" instead of a generic failure.
2026-03-06 13:06:37 +03:00
zhongnan.rex
242aecd924 feat(registry): add gemini-3.1-flash-image-preview model definition 2026-03-06 10:50:04 +08:00
hkfires
ce8cc1ba33 fix(translator): pass through adaptive thinking effort 2026-03-06 09:13:32 +08:00
Kirill Turanskiy
97fdd2e088 fix: preserve original JSON bytes in normalizeCacheControlTTL when no TTL change needed
normalizeCacheControlTTL unconditionally re-serializes the entire request
body through json.Unmarshal/json.Marshal even when no TTL normalization
is needed. Go's json.Marshal randomizes map key order and HTML-escapes
<, >, & characters (to \u003c, \u003e, \u0026), producing different raw
bytes on every call.

Anthropic's prompt caching uses byte-prefix matching, so any byte-level
difference causes a cache miss. This means the ~119K system prompt and
tools are re-processed on every request when routed through CPA.

The fix adds a bool return to normalizeTTLForBlock to indicate whether
it actually modified anything, and skips the marshal step in
normalizeCacheControlTTL when no blocks were changed.
2026-03-05 22:28:01 +03:00
Xu Hong
553d6f50ea fix: sanitize tool_use.id to comply with Claude API regex ^[a-zA-Z0-9_-]+$
Add util.SanitizeClaudeToolID() to replace non-conforming characters in
tool_use.id fields across all five response translators (gemini, codex,
openai, antigravity, gemini-cli).

Upstream tool names may contain dots or other special characters
(e.g. "fs.readFile") that violate Claude's ID validation regex.
The sanitizer replaces such characters with underscores and provides
a generated fallback for empty IDs.

Fixes #1872, Fixes #1849

Made-with: Cursor
2026-03-06 00:10:09 +08:00
lyd123qw2008
dd44413ba5 refactor(watcher): make authSliceToMap always return map 2026-03-02 10:09:56 +08:00
lyd123qw2008
10fa0f2062 refactor(watcher): dedupe auth map conversion in incremental flow 2026-03-02 10:03:42 +08:00
lyd123qw2008
30338ecec4 perf(watcher): remove redundant auth clones in incremental path 2026-03-01 14:05:11 +08:00
lyd123qw2008
9a37defed3 test(watcher): restore main test names and max-retry callback coverage 2026-03-01 13:54:03 +08:00
lyd123qw2008
c83a057996 refactor(watcher): make auth file events fully incremental 2026-03-01 13:42:42 +08:00
Alexey Yanchenko
b7588428c5 fix: preserve input_audio content parts when proxying to Antigravity
- Add input_audio handling in chat/completions translator (antigravity_openai_request.go)
- Add input_audio handling in responses translator (gemini_openai-responses_request.go)
- Map OpenAI audio formats (mp3, wav, ogg, flac, aac, webm, pcm16, g711_ulaw, g711_alaw) to correct MIME types for Gemini inlineData
2026-02-23 20:50:28 +07:00
44 changed files with 4783 additions and 713 deletions

117
README.md
View File

@@ -8,123 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
The Plus release stays in lockstep with the mainline features.
## Differences from the Mainline
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
## New Features (Plus Enhanced)
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & GLM-5 Only Available for Pro Usersmodel across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
## Kiro Authentication
### CLI Login
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
**AWS Builder ID** (recommended):
```bash
# Device code flow
./CLIProxyAPI --kiro-aws-login
# Authorization code flow
./CLIProxyAPI --kiro-aws-authcode
```
**Import token from Kiro IDE:**
```bash
./CLIProxyAPI --kiro-import
```
To get a token from Kiro IDE:
1. Open Kiro IDE and login with Google (or GitHub)
2. Find the token file: `~/.kiro/kiro-auth-token.json`
3. Run: `./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC):**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# Specify region
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**Additional flags:**
| Flag | Description |
|------|-------------|
| `--no-browser` | Don't open browser automatically, print URL instead |
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
### Web-based OAuth Login
Access the Kiro OAuth web interface at:
```
http://your-server:8080/v0/oauth/kiro
```
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
- AWS Builder ID login
- AWS Identity Center (IDC) login
- Token import from Kiro IDE
## Quick Deployment with Docker
### One-Command Deployment
```bash
# Create deployment directory
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# Create docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# Download example config
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# Pull and start
docker compose pull && docker compose up -d
```
### Configuration
Edit `config.yaml` before starting:
```yaml
# Basic configuration example
server:
port: 8317
# Add your provider configurations here
```
### Update to Latest Version
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## Contributing
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.

View File

@@ -8,123 +8,6 @@
该 Plus 版本的主线功能与主线功能强制同步。
## 与主线版本版本差异
[![bigmodel.cn](https://assets.router-for.me/chinese-5-0.jpg)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
## 新增功能 (Plus 增强版)
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7受限于算力目前仅限Pro用户开放为开发者提供顶尖的编码体验。
智谱AI为本产品提供了特别优惠使用以下链接购买可以享受九折优惠https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
### 命令行登录
> **注意:** 由于 AWS Cognito 限制Google/GitHub 登录不可用于第三方应用。
**AWS Builder ID**(推荐):
```bash
# 设备码流程
./CLIProxyAPI --kiro-aws-login
# 授权码流程
./CLIProxyAPI --kiro-aws-authcode
```
**从 Kiro IDE 导入令牌:**
```bash
./CLIProxyAPI --kiro-import
```
获取令牌步骤:
1. 打开 Kiro IDE使用 Google或 GitHub登录
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
3. 运行:`./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC)**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# 指定区域
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**附加参数:**
| 参数 | 说明 |
|------|------|
| `--no-browser` | 不自动打开浏览器,打印 URL |
| `--no-incognito` | 使用已有浏览器会话Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
| `--kiro-idc-start-url` | IDC Start URL`--kiro-idc-login` 必需) |
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1` |
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
### 网页端 OAuth 登录
访问 Kiro OAuth 网页认证界面:
```
http://your-server:8080/v0/oauth/kiro
```
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
- AWS Builder ID 登录
- AWS Identity Center (IDC) 登录
- 从 Kiro IDE 导入令牌
## Docker 快速部署
### 一键部署
```bash
# 创建部署目录
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# 创建 docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# 下载示例配置
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# 拉取并启动
docker compose pull && docker compose up -d
```
### 配置说明
启动前请编辑 `config.yaml`
```yaml
# 基本配置示例
server:
port: 8317
# 在此添加你的供应商配置
```
### 更新到最新版本
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## 贡献
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。

View File

@@ -219,6 +219,17 @@ nonstream-keepalive-interval: 0
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,
# # and if the chosen upstream fails before producing output, the request will
# # continue with the next upstream model in the same alias pool.
# - name: "qwen3.5-plus"
# alias: "claude-opus-4.66"
# - name: "glm-5"
# alias: "claude-opus-4.66"
# - name: "kimi-k2.5"
# alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
# vertex-api-key:

View File

@@ -1312,12 +1312,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
if errAll != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
return
}
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
return
}
ts.ProjectID = strings.Join(projects, ",")
@@ -1326,7 +1326,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ts.Auto = false
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
SetOAuthSessionError(state, "Google One auto-discovery failed")
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
return
}
if strings.TrimSpace(ts.ProjectID) == "" {
@@ -1337,19 +1337,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the auto-discovered project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return
}
} else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
return
}
@@ -1362,13 +1362,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the selected project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return
}
}

View File

@@ -211,6 +211,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-flash-image-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-flash-image-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Flash Image Preview",
Description: "Gemini 3.1 Flash Image Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -351,6 +366,17 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-flash-image-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-flash-image-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Flash Image Preview",
Description: "Gemini 3.1 Flash Image Preview",
},
{
ID: "gemini-3.1-flash-lite-preview",
Object: "model",

View File

@@ -64,6 +64,11 @@ type ModelInfo struct {
UserDefined bool `json:"-"`
}
type availableModelsCacheEntry struct {
models []map[string]any
expiresAt time.Time
}
// ThinkingSupport describes a model family's supported internal reasoning budget range.
// Values are interpreted in provider-native token units.
type ThinkingSupport struct {
@@ -118,6 +123,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
availableModelsCache map[string]availableModelsCacheEntry
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
@@ -130,15 +137,28 @@ var registryOnce sync.Once
func GetGlobalRegistry() *ModelRegistry {
registryOnce.Do(func() {
globalRegistry = &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
mutex: &sync.RWMutex{},
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
availableModelsCache: make(map[string]availableModelsCacheEntry),
mutex: &sync.RWMutex{},
}
})
return globalRegistry
}
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
if r.availableModelsCache == nil {
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
}
}
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
if len(r.availableModelsCache) == 0 {
return
}
clear(r.availableModelsCache)
}
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
@@ -153,9 +173,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
}
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
return info
return cloneModelInfo(info)
}
return LookupStaticModelInfo(modelID)
return cloneModelInfo(LookupStaticModelInfo(modelID))
}
// SetHook sets an optional hook for observing model registration changes.
@@ -213,6 +233,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
provider := strings.ToLower(clientProvider)
uniqueModelIDs := make([]string, 0, len(models))
@@ -238,6 +259,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientModels, clientID)
delete(r.clientModelInfos, clientID)
delete(r.clientProviders, clientID)
r.invalidateAvailableModelsCacheLocked()
misc.LogCredentialSeparator()
return
}
@@ -265,6 +287,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
@@ -408,6 +431,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
@@ -511,6 +535,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedOutputModalities) > 0 {
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
}
if model.Thinking != nil {
copyThinking := *model.Thinking
if len(model.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
}
copyModel.Thinking = &copyThinking
}
return &copyModel
}
@@ -540,6 +571,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.unregisterClientInternal(clientID)
r.invalidateAvailableModelsCacheLocked()
}
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
@@ -606,9 +638,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists {
registration.QuotaExceededClients[clientID] = new(time.Now())
now := time.Now()
registration.QuotaExceededClients[clientID] = &now
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
}
@@ -620,9 +655,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists {
delete(registration.QuotaExceededClients, clientID)
r.invalidateAvailableModelsCacheLocked()
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
}
}
@@ -638,6 +675,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID]
if !exists || registration == nil {
@@ -651,6 +689,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
}
registration.SuspendedClients[clientID] = reason
registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
if reason != "" {
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
} else {
@@ -668,6 +707,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID]
if !exists || registration == nil || registration.SuspendedClients == nil {
@@ -678,6 +718,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
}
delete(registration.SuspendedClients, clientID)
registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Resumed client %s for model %s", clientID, modelID)
}
@@ -713,22 +754,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
// Returns:
// - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
r.mutex.RLock()
defer r.mutex.RUnlock()
now := time.Now()
models := make([]map[string]any, 0)
r.mutex.RLock()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
models := cloneModelMaps(cache.models)
r.mutex.RUnlock()
return models
}
r.mutex.RUnlock()
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
return cloneModelMaps(cache.models)
}
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
models: cloneModelMaps(models),
expiresAt: expiresAt,
}
return models
}
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models))
quotaExpiredDuration := 5 * time.Minute
var expiresAt time.Time
for _, registration := range r.models {
// Check if model has any non-quota-exceeded clients
availableClients := registration.Count
now := time.Now()
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime == nil {
continue
}
recoveryAt := quotaTime.Add(quotaExpiredDuration)
if now.Before(recoveryAt) {
expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
expiresAt = recoveryAt
}
}
}
@@ -749,7 +820,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
effectiveClients = 0
}
// Include models that have available clients, or those solely cooling down.
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
model := r.convertModelToMap(registration.Info, handlerType)
if model != nil {
@@ -758,7 +828,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
}
}
return models
return models, expiresAt
}
func cloneModelMaps(models []map[string]any) []map[string]any {
cloned := make([]map[string]any, 0, len(models))
for _, model := range models {
if model == nil {
cloned = append(cloned, nil)
continue
}
copyModel := make(map[string]any, len(model))
for key, value := range model {
copyModel[key] = cloneModelMapValue(value)
}
cloned = append(cloned, copyModel)
}
return cloned
}
func cloneModelMapValue(value any) any {
switch typed := value.(type) {
case map[string]any:
copyMap := make(map[string]any, len(typed))
for key, entry := range typed {
copyMap[key] = cloneModelMapValue(entry)
}
return copyMap
case []any:
copySlice := make([]any, len(typed))
for i, entry := range typed {
copySlice[i] = cloneModelMapValue(entry)
}
return copySlice
case []string:
return append([]string(nil), typed...)
default:
return value
}
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
@@ -874,11 +981,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil {
result = append(result, entry.info)
result = append(result, cloneModelInfo(entry.info))
continue
}
if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info)
result = append(result, cloneModelInfo(registration.Info))
}
}
}
@@ -987,13 +1094,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
if reg.Providers != nil {
if count, ok := reg.Providers[provider]; ok && count > 0 {
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
return info
return cloneModelInfo(info)
}
}
}
}
// Fallback to global info (last registered)
return reg.Info
return cloneModelInfo(reg.Info)
}
return nil
}
@@ -1033,7 +1140,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["max_completion_tokens"] = model.MaxCompletionTokens
}
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedEndpoints) > 0 {
result["supported_endpoints"] = model.SupportedEndpoints
@@ -1094,13 +1201,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["outputTokenLimit"] = model.OutputTokenLimit
}
if len(model.SupportedGenerationMethods) > 0 {
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedInputModalities) > 0 {
result["supportedInputModalities"] = model.SupportedInputModalities
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
}
if len(model.SupportedOutputModalities) > 0 {
result["supportedOutputModalities"] = model.SupportedOutputModalities
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
}
return result
@@ -1130,15 +1237,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
invalidated := false
for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
delete(registration.QuotaExceededClients, clientID)
invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
}
}
}
if invalidated {
r.invalidateAvailableModelsCacheLocked()
}
}
// GetFirstAvailableModel returns the first available model for the given handler type.
@@ -1152,8 +1264,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
// - string: The model ID of the first available model, or empty string if none available
// - error: An error if no models are available
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Get all available models for this handler type
models := r.GetAvailableModels(handlerType)
@@ -1213,13 +1323,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
// Prefer client's own model info to preserve original type/owned_by
if clientInfos != nil {
if info, ok := clientInfos[modelID]; ok && info != nil {
result = append(result, info)
result = append(result, cloneModelInfo(info))
continue
}
}
// Fallback to global registry (for backwards compatibility)
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
result = append(result, reg.Info)
result = append(result, cloneModelInfo(reg.Info))
}
}
return result

View File

@@ -0,0 +1,54 @@
package registry
import "testing"
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
first := r.GetAvailableModels("openai")
if len(first) != 1 {
t.Fatalf("expected 1 model, got %d", len(first))
}
first[0]["id"] = "mutated"
first[0]["display_name"] = "Mutated"
second := r.GetAvailableModels("openai")
if got := second[0]["id"]; got != "m1" {
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
}
if got := second[0]["display_name"]; got != "Model One" {
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
}
}
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected 1 model, got %d", len(models))
}
if got := models[0]["display_name"]; got != "Model One" {
t.Fatalf("expected initial display_name Model One, got %v", got)
}
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
models = r.GetAvailableModels("openai")
if got := models[0]["display_name"]; got != "Model One Updated" {
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
}
r.SuspendClientModel("client-1", "m1", "manual")
models = r.GetAvailableModels("openai")
if len(models) != 0 {
t.Fatalf("expected no available models after suspension, got %d", len(models))
}
r.ResumeClientModel("client-1", "m1")
models = r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to reappear after resume, got %d", len(models))
}
}

View File

@@ -0,0 +1,149 @@
package registry
import (
"testing"
"time"
)
func TestGetModelInfoReturnsClone(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
}})
first := r.GetModelInfo("m1", "gemini")
if first == nil {
t.Fatal("expected model info")
}
first.DisplayName = "mutated"
first.Thinking.Levels[0] = "mutated"
second := r.GetModelInfo("m1", "gemini")
if second.DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
}
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
}
}
func TestGetModelsForClientReturnsClones(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
}})
first := r.GetModelsForClient("client-1")
if len(first) != 1 || first[0] == nil {
t.Fatalf("expected one model, got %+v", first)
}
first[0].DisplayName = "mutated"
first[0].Thinking.Levels[0] = "mutated"
second := r.GetModelsForClient("client-1")
if len(second) != 1 || second[0] == nil {
t.Fatalf("expected one model on second fetch, got %+v", second)
}
if second[0].DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
}
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
}
}
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
}})
first := r.GetAvailableModelsByProvider("gemini")
if len(first) != 1 || first[0] == nil {
t.Fatalf("expected one model, got %+v", first)
}
first[0].DisplayName = "mutated"
first[0].Thinking.Levels[0] = "mutated"
second := r.GetAvailableModelsByProvider("gemini")
if len(second) != 1 || second[0] == nil {
t.Fatalf("expected one model on second fetch, got %+v", second)
}
if second[0].DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
}
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
}
}
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
r.SetModelQuotaExceeded("client-1", "m1")
if models := r.GetAvailableModels("openai"); len(models) != 1 {
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
}
r.mutex.Lock()
quotaTime := time.Now().Add(-6 * time.Minute)
r.models["m1"].QuotaExceededClients["client-1"] = &quotaTime
r.mutex.Unlock()
r.CleanupExpiredQuotas()
if count := r.GetModelCount("m1"); count != 1 {
t.Fatalf("expected model count 1 after cleanup, got %d", count)
}
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
}
if got := models[0]["id"]; got != "m1" {
t.Fatalf("expected model id m1, got %v", got)
}
}
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
SupportedParameters: []string{"temperature", "top_p"},
}})
first := r.GetAvailableModels("openai")
if len(first) != 1 {
t.Fatalf("expected one model, got %d", len(first))
}
params, ok := first[0]["supported_parameters"].([]string)
if !ok || len(params) != 2 {
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
}
params[0] = "mutated"
second := r.GetAvailableModels("openai")
params, ok = second[0]["supported_parameters"].([]string)
if !ok || len(params) != 2 || params[0] != "temperature" {
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
}
}
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
first := LookupModelInfo("glm-4.6")
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
t.Fatalf("expected static model with thinking levels, got %+v", first)
}
first.Thinking.Levels[0] = "mutated"
second := LookupModelInfo("glm-4.6")
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
t.Fatalf("expected static lookup clone, got %+v", second)
}
}

View File

@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
}
return true
})
} else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
partJSON, _ = sjson.Set(partJSON, "text", system.String())
result += "," + partJSON
}
result += "]"
@@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int {
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"]
if !exists {
return
return false
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return
return false
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return
return false
}
if *seen5m {
delete(cc, "ttl")
return true
}
return false
}
func findLastCacheControlIndex(arr []any) int {
@@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
}
seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
@@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
@@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
}
for _, item := range content {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
}
if !modified {
return payload
}
return marshalPayloadObject(payload, root)
}

View File

@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
}
}
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
// Payload where no TTL normalization is needed (all blocks use 1h with no
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
// that json.Marshal would escape to \u003c etc., altering byte identity.
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := normalizeCacheControlTTL(payload)
if !bytes.Equal(out, payload) {
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
system := gjson.GetBytes(out, "system")
if !system.IsArray() {
t.Fatalf("system should be an array, got %s", system.Type)
}
blocks := system.Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
}
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
}
if blocks[2].Get("text").String() != "You are a helpful assistant." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
}
}
// Test case 2: Strict mode drops the string system prompt
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, true)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 3: Empty string system prompt does not produce a spurious block
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 4: Array system prompt is unaffected by the string handling
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != "Be concise." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
}
// Test case 5: Special characters in string system prompt survive conversion
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}

View File

@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
// Imagen models don't support streaming, skip SSE params
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")

View File

@@ -522,9 +522,9 @@ func detectLastConversationRole(body []byte) string {
}
switch item.Get("type").String() {
case "function_call", "function_call_arguments":
case "function_call", "function_call_arguments", "computer_call":
return "assistant"
case "function_call_output", "function_call_response", "tool_result":
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
return "tool"
}
}
@@ -832,6 +832,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
if tools.IsArray() {
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if isGitHubCopilotResponsesBuiltinTool(toolType) {
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
continue
}
// Accept OpenAI format (type="function") and Claude format
// (no type field, but has top-level name + input_schema).
if toolType != "" && toolType != "function" {
@@ -879,6 +883,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
}
if toolChoice.Type == gjson.JSON {
choiceType := toolChoice.Get("type").String()
if isGitHubCopilotResponsesBuiltinTool(choiceType) {
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(toolChoice.Raw))
return body
}
if choiceType == "function" {
name := toolChoice.Get("name").String()
if name == "" {
@@ -896,6 +904,15 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
return body
}
func isGitHubCopilotResponsesBuiltinTool(toolType string) bool {
switch strings.TrimSpace(toolType) {
case "computer", "computer_use_preview":
return true
default:
return false
}
}
func collectTextFromNode(node gjson.Result) string {
if !node.Exists() {
return ""

View File

@@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
if suffixResult.HasSuffix {
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
} else {
config = extractThinkingConfig(body, toFormat)
config = extractThinkingConfig(body, fromFormat)
if !hasThinkingConfig(config) && fromFormat != toFormat {
config = extractThinkingConfig(body, toFormat)
}
}
if !hasThinkingConfig(config) {
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
if config.Mode != ModeLevel {
return config
}
if toFormat == "claude" {
return config
}
if !isBudgetCapableProvider(toFormat) {
return config
}

View File

@@ -0,0 +1,55 @@
package thinking_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
"github.com/tidwall/gjson"
)
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-user-defined-claude-" + t.Name()
modelID := "custom-claude-4-6"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
t.Cleanup(func() {
reg.UnregisterClient(clientID)
})
tests := []struct {
name string
model string
body []byte
}{
{
name: "claude adaptive effort body",
model: modelID,
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
},
{
name: "suffix level",
model: modelID + "(high)",
body: []byte(`{}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
if err != nil {
t.Fatalf("ApplyThinking() error = %v", err)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
}
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
}
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
}
})
}
}

View File

@@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
if effort == "max" {
effort = "high"
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")

View File

@@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
}
}
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
tests := []struct {
name string
effort string
expected string
}{
{"low", "low", "low"},
{"medium", "medium", "medium"},
{"high", "high", "high"},
{"max", "max", "high"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"},
"output_config": {"effort": "` + tt.effort + `"}
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
outputStr := string(output)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if !thinkingConfig.Exists() {
t.Fatal("thinkingConfig should exist for adaptive thinking")
}
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
})
}
}
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"}
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
outputStr := string(output)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if !thinkingConfig.Exists() {
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
}
if thinkingConfig.Get("thinkingLevel").String() != "high" {
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
}

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)

View File

@@ -212,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else {
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
}
case "input_audio":
audioData := item.Get("input_audio.data").String()
audioFormat := item.Get("input_audio.format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
p++
}
}
}
}

View File

@@ -203,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
case "image_url":
// Convert OpenAI image format to Claude Code format
imageURL := part.Get("image_url.url").String()
if strings.HasPrefix(imageURL, "data:") {
// Extract base64 data and media type from data URL
parts := strings.Split(imageURL, ",")
if len(parts) == 2 {
mediaTypePart := strings.Split(parts[0], ";")[0]
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
data := parts[1]
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
}
return true
})
@@ -291,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "tool":
// Handle tool result messages conversion
toolCallID := message.Get("tool_call_id").String()
content := message.Get("content").String()
toolContentResult := message.Get("content")
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
msg, _ = sjson.Set(msg, "content.0.content", content)
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
if toolResultContentRaw {
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
} else {
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
messageIndex++
}
@@ -358,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
return []byte(out)
}
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
switch part.Get("type").String() {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
return textPart
case "image_url":
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
return docPart
}
}
}
return ""
}
func convertOpenAIImageURLToClaudePart(imageURL string) string {
if imageURL == "" {
return ""
}
if strings.HasPrefix(imageURL, "data:") {
parts := strings.SplitN(imageURL, ",", 2)
if len(parts) != 2 {
return ""
}
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
if mediaType == "" {
mediaType = "application/octet-stream"
}
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
return imagePart
}
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
return imagePart
}
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
if !content.Exists() {
return "", false
}
if content.Type == gjson.String {
return content.String(), false
}
if content.IsArray() {
claudeContent := "[]"
partCount := 0
content.ForEach(func(_, part gjson.Result) bool {
if part.Type == gjson.String {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.String())
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
partCount++
return true
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
partCount++
}
return true
})
if partCount > 0 || len(content.Array()) == 0 {
return claudeContent, true
}
return content.Raw, false
}
if content.IsObject() {
claudePart := convertOpenAIContentPartToClaudePart(content)
if claudePart != "" {
claudeContent := "[]"
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
return claudeContent, true
}
return content.Raw, false
}
return content.Raw, false
}

View File

@@ -0,0 +1,137 @@
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolResult := messages[1].Get("content.0")
if got := toolResult.Get("type").String(); got != "tool_result" {
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
}
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
}
toolContent := toolResult.Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image" {
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("1.source.type").String(); got != "base64" {
t.Fatalf("Expected image source type %q, got %q", "base64", got)
}
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
}
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected base64 image data: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://example.com/tool.png"
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content.0.content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image" {
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("0.source.type").String(); got != "url" {
t.Fatalf("Expected image source type %q, got %q", "url", got)
}
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image URL: %q", got)
}
}

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {

View File

@@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)

View File

@@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", clientToolName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {

View File

@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
content := m.Get("content")
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style
// system -> systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
systemPartIndex++
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
systemPartIndex++
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
systemPartIndex++
}
}

View File

@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if instructions := root.Get("instructions"); instructions.Exists() {
systemInstr := `{"parts":[{"text":""}]}`
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
// Convert input messages to Gemini contents format
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if strings.EqualFold(itemRole, "system") {
if contentArray := item.Get("content"); contentArray.Exists() {
systemInstr := ""
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
systemInstr = systemInstructionResult.Raw
} else {
systemInstr = `{"parts":[]}`
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
}
if systemInstr != `{"parts":[]}` {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
}
continue
@@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
}
}
case "input_audio":
audioData := contentItem.Get("data").String()
audioFormat := contentItem.Get("format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
}
}
if partJSON != "" {

View File

@@ -183,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
if toolResultContentRaw {
toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
} else {
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
}
toolResults = append(toolResults, toolResultJSON)
}
return true
@@ -374,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
}
}
func convertClaudeToolResultContentToString(content gjson.Result) string {
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
if !content.Exists() {
return ""
return "", false
}
if content.Type == gjson.String {
return content.String()
return content.String(), false
}
if content.IsArray() {
var parts []string
contentJSON := "[]"
hasImagePart := false
content.ForEach(func(_, item gjson.Result) bool {
switch {
case item.Type == gjson.String:
parts = append(parts, item.String())
text := item.String()
parts = append(parts, text)
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text)
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
case item.IsObject() && item.Get("type").String() == "text":
text := item.Get("text").String()
parts = append(parts, text)
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text)
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
case item.IsObject() && item.Get("type").String() == "image":
contentItem, ok := convertClaudeContentPart(item)
if ok {
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
hasImagePart = true
} else {
parts = append(parts, item.Raw)
}
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
parts = append(parts, item.Get("text").String())
default:
@@ -397,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string {
return true
})
if hasImagePart {
return contentJSON, true
}
joined := strings.Join(parts, "\n\n")
if strings.TrimSpace(joined) != "" {
return joined
return joined, false
}
return content.Raw
return content.Raw, false
}
if content.IsObject() {
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
if content.Get("type").String() == "image" {
contentItem, ok := convertClaudeContentPart(content)
if ok {
contentJSON := "[]"
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
return contentJSON, true
}
}
return content.Raw
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String(), false
}
return content.Raw, false
}
return content.Raw
return content.Raw, false
}

View File

@@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image_url" {
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": {
"type": "image",
"source": {
"type": "url",
"url": "https://example.com/tool.png"
}
}
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image_url" {
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",

View File

@@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_start for tool_use
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
}
@@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
@@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
argsStr := util.FixJSON(tc.Get("function.arguments").String())
@@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())

View File

@@ -0,0 +1,24 @@
package util
import (
"fmt"
"regexp"
"sync/atomic"
"time"
)
var (
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
claudeToolUseIDCounter uint64
)
// SanitizeClaudeToolID ensures the given id conforms to Claude's
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
// replaced with '_'; an empty result gets a generated fallback.
func SanitizeClaudeToolID(id string) string {
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
if s == "" {
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
}
return s
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
@@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
w.lastAuthHashes = make(map[string]string)
w.lastAuthContents = make(map[string]*coreauth.Auth)
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
} else if resolvedAuthDir != "" {
@@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
if errParse := json.Unmarshal(data, &auth); errParse == nil {
w.lastAuthContents[normalizedPath] = &auth
}
ctx := &synthesizer.SynthesisContext{
Config: cfg,
AuthDir: resolvedAuthDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
w.fileAuthsByPath[normalizedPath] = pathAuths
}
}
}
}
return nil
@@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
}
w.clientsMutex.Lock()
cfg := w.config
if cfg == nil {
if w.config == nil {
log.Error("config is nil, cannot add or update client")
w.clientsMutex.Unlock()
return
}
if w.fileAuthsByPath == nil {
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
}
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
w.clientsMutex.Unlock()
@@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) {
}
w.lastAuthContents[normalized] = &newAuth
w.clientsMutex.Unlock() // Unlock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after add/update")
w.triggerServerUpdate(cfg)
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
for id, a := range w.fileAuthsByPath[normalized] {
oldByID[id] = a
}
// Build synthesized auth entries for this single file only.
sctx := &synthesizer.SynthesisContext{
Config: w.config,
AuthDir: w.authDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
newByID := authSliceToMap(generated)
if len(newByID) > 0 {
w.fileAuthsByPath[normalized] = newByID
} else {
delete(w.fileAuthsByPath, normalized)
}
updates := w.computePerPathUpdatesLocked(oldByID, newByID)
w.clientsMutex.Unlock()
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) removeClient(path string) {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
for id, a := range w.fileAuthsByPath[normalized] {
oldByID[id] = a
}
delete(w.lastAuthHashes, normalized)
delete(w.lastAuthContents, normalized)
delete(w.fileAuthsByPath, normalized)
w.clientsMutex.Unlock() // Release the lock before the callback
updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{})
w.clientsMutex.Unlock()
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after removal")
w.triggerServerUpdate(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))
for id, newAuth := range newByID {
existing, ok := w.currentAuths[id]
if !ok {
w.currentAuths[id] = newAuth.Clone()
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()})
continue
}
if !authEqual(existing, newAuth) {
w.currentAuths[id] = newAuth.Clone()
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()})
}
}
for id := range oldByID {
if _, stillExists := newByID[id]; stillExists {
continue
}
delete(w.currentAuths, id)
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
}
return updates
}
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
byID := make(map[string]*coreauth.Auth, len(auths))
for _, a := range auths {
if a == nil || strings.TrimSpace(a.ID) == "" {
continue
}
byID[a.ID] = a
}
return byID
}
func (w *Watcher) loadFileClients(cfg *config.Config) int {

View File

@@ -14,6 +14,8 @@ import (
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
var snapshotCoreAuthsFunc = snapshotCoreAuths
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
@@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
}
func (w *Watcher) refreshAuthState(force bool) {
auths := w.SnapshotCoreAuths()
w.clientsMutex.RLock()
cfg := w.config
authDir := w.authDir
w.clientsMutex.RUnlock()
auths := snapshotCoreAuthsFunc(cfg, authDir)
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {

View File

@@ -36,9 +36,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
return out, nil
}
now := ctx.Now
cfg := ctx.Config
for _, e := range entries {
if e.IsDir() {
continue
@@ -52,99 +49,120 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
if errRead != nil || len(data) == 0 {
continue
}
var metadata map[string]any
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
auths := synthesizeFileAuths(ctx, full, data)
if len(auths) == 0 {
continue
}
t, _ := metadata["type"].(string)
if t == "" {
continue
}
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
}
label := provider
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// Use relative path under authDir as ID to stay consistent with the file-based token store
id := full
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
id = rel
}
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
proxyURL := ""
if p, ok := metadata["proxy_url"].(string); ok {
proxyURL = p
}
prefix := ""
if rawPrefix, ok := metadata["prefix"].(string); ok {
trimmed := strings.TrimSpace(rawPrefix)
trimmed = strings.Trim(trimmed, "/")
if trimmed != "" && !strings.Contains(trimmed, "/") {
prefix = trimmed
}
}
disabled, _ := metadata["disabled"].(bool)
status := coreauth.StatusActive
if disabled {
status = coreauth.StatusDisabled
}
// Read per-account excluded models from the OAuth JSON file
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
a := &coreauth.Auth{
ID: id,
Provider: provider,
Label: label,
Prefix: prefix,
Status: status,
Disabled: disabled,
Attributes: map[string]string{
"source": full,
"path": full,
},
ProxyURL: proxyURL,
Metadata: metadata,
CreatedAt: now,
UpdatedAt: now,
}
// Read priority from auth file
if rawPriority, ok := metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
a.Attributes["priority"] = strconv.Itoa(int(v))
case string:
priority := strings.TrimSpace(v)
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
a.Attributes["priority"] = priority
}
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
}
out = append(out, a)
out = append(out, virtuals...)
continue
}
}
out = append(out, a)
out = append(out, auths...)
}
return out, nil
}
// SynthesizeAuthFile generates Auth entries for one auth JSON file payload.
// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize.
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
return synthesizeFileAuths(ctx, fullPath, data)
}
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
if ctx == nil || len(data) == 0 {
return nil
}
now := ctx.Now
cfg := ctx.Config
var metadata map[string]any
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
return nil
}
t, _ := metadata["type"].(string)
if t == "" {
return nil
}
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
}
label := provider
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// Use relative path under authDir as ID to stay consistent with the file-based token store.
id := fullPath
if strings.TrimSpace(ctx.AuthDir) != "" {
if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
id = rel
}
}
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
proxyURL := ""
if p, ok := metadata["proxy_url"].(string); ok {
proxyURL = p
}
prefix := ""
if rawPrefix, ok := metadata["prefix"].(string); ok {
trimmed := strings.TrimSpace(rawPrefix)
trimmed = strings.Trim(trimmed, "/")
if trimmed != "" && !strings.Contains(trimmed, "/") {
prefix = trimmed
}
}
disabled, _ := metadata["disabled"].(bool)
status := coreauth.StatusActive
if disabled {
status = coreauth.StatusDisabled
}
// Read per-account excluded models from the OAuth JSON file.
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
a := &coreauth.Auth{
ID: id,
Provider: provider,
Label: label,
Prefix: prefix,
Status: status,
Disabled: disabled,
Attributes: map[string]string{
"source": fullPath,
"path": fullPath,
},
ProxyURL: proxyURL,
Metadata: metadata,
CreatedAt: now,
UpdatedAt: now,
}
// Read priority from auth file.
if rawPriority, ok := metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
a.Attributes["priority"] = strconv.Itoa(int(v))
case string:
priority := strings.TrimSpace(v)
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
a.Attributes["priority"] = priority
}
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
}
out := make([]*coreauth.Auth, 0, 1+len(virtuals))
out = append(out, a)
out = append(out, virtuals...)
return out
}
}
return []*coreauth.Auth{a}
}
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
// It disables the primary auth and creates one virtual auth per project.
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {

View File

@@ -45,6 +45,7 @@ type Watcher struct {
watcher *fsnotify.Watcher
lastAuthHashes map[string]string
lastAuthContents map[string]*coreauth.Auth
fileAuthsByPath map[string]map[string]*coreauth.Auth
lastRemoveTimes map[string]time.Time
lastConfigHash string
authQueue chan<- AuthUpdate
@@ -92,11 +93,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
return nil, errNewWatcher
}
w := &Watcher{
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
lastAuthHashes: make(map[string]string),
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
lastAuthHashes: make(map[string]string),
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
}
w.dispatchCond = sync.NewCond(&w.dispatchMu)
if store := sdkAuth.GetTokenStore(); store != nil {

View File

@@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
w.addOrUpdateClient(authFile)
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected reload callback once, got %d", got)
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no reload callback for auth update, got %d", got)
}
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
normalized := w.normalizeAuthPath(authFile)
@@ -436,8 +436,110 @@ func TestRemoveClientRemovesHash(t *testing.T) {
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected hash to be removed after deletion")
}
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected reload callback once, got %d", got)
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no reload callback for auth removal, got %d", got)
}
}
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
tmpDir := t.TempDir()
authFile := filepath.Join(tmpDir, "sample.json")
if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil {
t.Fatalf("failed to create auth file: %v", err)
}
origSnapshot := snapshotCoreAuthsFunc
var snapshotCalls int32
snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
atomic.AddInt32(&snapshotCalls, 1)
return origSnapshot(cfg, authDir)
}
defer func() { snapshotCoreAuthsFunc = origSnapshot }()
w := &Watcher{
authDir: tmpDir,
lastAuthHashes: make(map[string]string),
lastAuthContents: make(map[string]*coreauth.Auth),
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
}
w.SetConfig(&config.Config{AuthDir: tmpDir})
w.addOrUpdateClient(authFile)
w.removeClient(authFile)
if got := atomic.LoadInt32(&snapshotCalls); got != 0 {
t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got)
}
}
func TestAuthSliceToMap(t *testing.T) {
t.Parallel()
valid1 := &coreauth.Auth{ID: "a"}
valid2 := &coreauth.Auth{ID: "b"}
dupOld := &coreauth.Auth{ID: "dup", Label: "old"}
dupNew := &coreauth.Auth{ID: "dup", Label: "new"}
empty := &coreauth.Auth{ID: " "}
tests := []struct {
name string
in []*coreauth.Auth
want map[string]*coreauth.Auth
}{
{
name: "nil input",
in: nil,
want: map[string]*coreauth.Auth{},
},
{
name: "empty input",
in: []*coreauth.Auth{},
want: map[string]*coreauth.Auth{},
},
{
name: "filters invalid auths",
in: []*coreauth.Auth{nil, empty},
want: map[string]*coreauth.Auth{},
},
{
name: "keeps valid auths",
in: []*coreauth.Auth{valid1, nil, valid2},
want: map[string]*coreauth.Auth{"a": valid1, "b": valid2},
},
{
name: "last duplicate wins",
in: []*coreauth.Auth{dupOld, dupNew},
want: map[string]*coreauth.Auth{"dup": dupNew},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := authSliceToMap(tc.in)
if len(tc.want) == 0 {
if got == nil {
t.Fatal("expected empty map, got nil")
}
if len(got) != 0 {
t.Fatalf("expected empty map, got %#v", got)
}
return
}
if len(got) != len(tc.want) {
t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
}
for id, wantAuth := range tc.want {
gotAuth, ok := got[id]
if !ok {
t.Fatalf("missing id %q in result map", id)
}
if !authEqual(gotAuth, wantAuth) {
t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
}
}
})
}
}
@@ -695,8 +797,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) {
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected reload callback once, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected no reload callback for auth removal, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected hash entry to be removed")
@@ -893,8 +995,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads)
}
}
@@ -990,8 +1092,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads)
}
}
@@ -1045,8 +1147,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected known remove to avoid global reload, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected known auth hash to be deleted")

View File

@@ -14,7 +14,11 @@ import (
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -30,6 +34,8 @@ const (
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
wsBodyLogMaxSize = 64 * 1024
wsBodyLogTruncated = "\n[websocket log truncated]\n"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -100,11 +106,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
var requestJSON []byte
@@ -139,6 +151,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
requestJSON = updated
}
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
updatedLastRequest = updated
}
lastRequest = updatedLastRequest
lastResponseOutput = []byte("[]")
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
wsTerminateErr = errWrite
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
return
}
continue
}
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -339,6 +367,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
return false
}
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
if h == nil || h.AuthManager == nil {
return false
}
resolvedModelName := modelName
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
if initialSuffix.HasSuffix {
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
} else {
resolvedModelName = resolvedBase
}
} else {
resolvedModelName = util.ResolveAutoModel(modelName)
}
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
providers := util.GetProviderName(baseModel)
if len(providers) == 0 && baseModel != resolvedModelName {
providers = util.GetProviderName(resolvedModelName)
}
if len(providers) == 0 {
return false
}
providerSet := make(map[string]struct{}, len(providers))
for i := 0; i < len(providers); i++ {
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
if providerKey == "" {
continue
}
providerSet[providerKey] = struct{}{}
}
if len(providerSet) == 0 {
return false
}
modelKey := baseModel
if modelKey == "" {
modelKey = strings.TrimSpace(resolvedModelName)
}
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
for i := 0; i < len(auths); i++ {
auth := auths[i]
if auth == nil {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
continue
}
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
continue
}
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
}
return false
}
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
if auth == nil {
return false
}
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
return false
}
if modelName != "" && len(auth.ModelStates) > 0 {
state, ok := auth.ModelStates[modelName]
if (!ok || state == nil) && modelName != "" {
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
if baseModel != "" && baseModel != modelName {
state, ok = auth.ModelStates[baseModel]
}
}
if ok && state != nil {
if state.Status == coreauth.StatusDisabled {
return false
}
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
return false
}
return true
}
}
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
return false
}
return true
}
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
return false
}
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
return false
}
generateResult := gjson.GetBytes(rawJSON, "generate")
return generateResult.Exists() && !generateResult.Bool()
}
func writeResponsesWebsocketSyntheticPrewarm(
c *gin.Context,
conn *websocket.Conn,
requestJSON []byte,
wsBodyLog *strings.Builder,
sessionID string,
) error {
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
if errPayloads != nil {
return errPayloads
}
for i := 0; i < len(payloads); i++ {
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
// websocket.TextMessage,
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(payloads[i]),
errWrite,
)
return errWrite
}
}
return nil
}
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
responseID := "resp_prewarm_" + uuid.NewString()
createdAt := time.Now().Unix()
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
var errSet error
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
if errSet != nil {
return nil, errSet
}
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
if errSet != nil {
return nil, errSet
}
if modelName != "" {
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
if errSet != nil {
return nil, errSet
}
}
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
if errSet != nil {
return nil, errSet
}
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
if errSet != nil {
return nil, errSet
}
if modelName != "" {
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
if errSet != nil {
return nil, errSet
}
}
return [][]byte{createdPayload, completedPayload}, nil
}
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
existingRaw = strings.TrimSpace(existingRaw)
appendRaw = strings.TrimSpace(appendRaw)
@@ -550,65 +764,134 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
}
body := handlers.BuildErrorResponseBody(status, errText)
payload := map[string]any{
"type": wsEventTypeError,
"status": status,
payload := []byte(`{}`)
var errSet error
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
if errSet != nil {
return nil, errSet
}
payload, errSet = sjson.SetBytes(payload, "status", status)
if errSet != nil {
return nil, errSet
}
if errMsg != nil && errMsg.Addon != nil {
headers := map[string]any{}
headers := []byte(`{}`)
hasHeaders := false
for key, values := range errMsg.Addon {
if len(values) == 0 {
continue
}
headers[key] = values[0]
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
if errSet != nil {
return nil, errSet
}
hasHeaders = true
}
if len(headers) > 0 {
payload["headers"] = headers
}
}
if len(body) > 0 && json.Valid(body) {
var decoded map[string]any
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
if inner, ok := decoded["error"]; ok {
payload["error"] = inner
} else {
payload["error"] = decoded
if hasHeaders {
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
if errSet != nil {
return nil, errSet
}
}
}
if _, ok := payload["error"]; !ok {
payload["error"] = map[string]any{
"type": "server_error",
"message": errText,
if len(body) > 0 && json.Valid(body) {
errorNode := gjson.GetBytes(body, "error")
if errorNode.Exists() {
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
} else {
payload, errSet = sjson.SetRawBytes(payload, "error", body)
}
if errSet != nil {
return nil, errSet
}
}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
if !gjson.GetBytes(payload, "error").Exists() {
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
if errSet != nil {
return nil, errSet
}
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
if errSet != nil {
return nil, errSet
}
}
return data, conn.WriteMessage(websocket.TextMessage, data)
return payload, conn.WriteMessage(websocket.TextMessage, payload)
}
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
if builder == nil {
return
}
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
if !appendWebsocketLogString(builder, "\n") {
return
}
}
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
if !appendWebsocketLogString(builder, "websocket.") {
return
}
if !appendWebsocketLogString(builder, eventType) {
return
}
if !appendWebsocketLogString(builder, "\n") {
return
}
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
appendWebsocketLogString(builder, wsBodyLogTruncated)
return
}
appendWebsocketLogString(builder, "\n")
}
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.WriteString(value)
return true
}
builder.WriteString(value[:remaining])
return false
}
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.Write(value)
return true
}
limit := remaining - reserveForSuffix
if limit < 0 {
limit = 0
}
if limit > len(value) {
limit = len(value)
}
builder.Write(value[:limit])
return false
}
func websocketPayloadEventType(payload []byte) string {

View File

@@ -2,7 +2,9 @@ package openai
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
@@ -11,9 +13,46 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
type websocketCaptureExecutor struct {
streamCalls int
payloads [][]byte
}
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.streamCalls++
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
@@ -227,6 +266,34 @@ func TestAppendWebsocketEvent(t *testing.T) {
}
}
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
var builder strings.Builder
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
appendWebsocketEvent(&builder, "request", payload)
got := builder.String()
if len(got) > wsBodyLogMaxSize {
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
}
if !strings.Contains(got, wsBodyLogTruncated) {
t.Fatalf("expected truncation marker in body log")
}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
@@ -326,3 +393,130 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
t.Fatalf("server error: %v", errServer)
}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{
ID: "auth-ws",
Provider: "test-provider",
Status: coreauth.StatusActive,
Attributes: map[string]string{"websockets": "true"},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
t.Fatalf("expected websocket-capable upstream for test-model")
}
}
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketCaptureExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
errClose := conn.Close()
if errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
if errWrite != nil {
t.Fatalf("write prewarm websocket message: %v", errWrite)
}
_, createdPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read prewarm created message: %v", errReadMessage)
}
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
}
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
if prewarmResponseID == "" {
t.Fatalf("prewarm response id is empty")
}
if executor.streamCalls != 0 {
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
}
_, completedPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read prewarm completed message: %v", errReadMessage)
}
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
}
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
}
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
}
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
if errWrite != nil {
t.Fatalf("write follow-up websocket message: %v", errWrite)
}
_, upstreamPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read upstream completed message: %v", errReadMessage)
}
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
}
if executor.streamCalls != 1 {
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
}
if len(executor.payloads) != 1 {
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
}
forwarded := executor.payloads[0]
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
}
if gjson.GetBytes(forwarded, "generate").Exists() {
t.Fatalf("generate leaked upstream: %s", forwarded)
}
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
}
input := gjson.GetBytes(forwarded, "input").Array()
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected forwarded input: %s", forwarded)
}
}

View File

@@ -134,6 +134,7 @@ type Manager struct {
hook Hook
mu sync.RWMutex
auths map[string]*Auth
scheduler *authScheduler
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
@@ -149,6 +150,9 @@ type Manager struct {
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
apiKeyModelAlias atomic.Value
// modelPoolOffsets tracks per-auth alias pool rotation state.
modelPoolOffsets map[string]int
// runtimeConfig stores the latest application config for request-time decisions.
// It is initialized in NewManager; never Load() before first Store().
runtimeConfig atomic.Value
@@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
hook: hook,
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{})
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
manager.scheduler = newAuthScheduler(selector)
return manager
}
func isBuiltInSelector(selector Selector) bool {
switch selector.(type) {
case *RoundRobinSelector, *FillFirstSelector:
return true
default:
return false
}
}
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
if m == nil || m.scheduler == nil {
return
}
m.scheduler.rebuild(auths)
}
func (m *Manager) syncScheduler() {
if m == nil || m.scheduler == nil {
return
}
m.syncSchedulerFromSnapshot(m.snapshotAuths())
}
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
// supportedModelSet is rebuilt from the current global model registry state.
// This must be called after models have been registered for a newly added auth,
// because the initial scheduler.upsertAuth during Register/Update runs before
// registerModelsForAuth and therefore snapshots an empty model set.
func (m *Manager) RefreshSchedulerEntry(authID string) {
if m == nil || m.scheduler == nil || authID == "" {
return
}
m.mu.RLock()
auth, ok := m.auths[authID]
if !ok || auth == nil {
m.mu.RUnlock()
return
}
snapshot := auth.Clone()
m.mu.RUnlock()
m.scheduler.upsertAuth(snapshot)
}
func (m *Manager) SetSelector(selector Selector) {
if m == nil {
return
@@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) {
m.mu.Lock()
m.selector = selector
m.mu.Unlock()
if m.scheduler != nil {
m.scheduler.setSelector(selector)
m.syncScheduler()
}
}
// SetStore swaps the underlying persistence store.
@@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin
if resolved == "" {
return ""
}
// Preserve thinking suffix from the client's requested model unless config already has one.
requestResult := thinking.ParseSuffix(requestedModel)
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
return preserveRequestedModelSuffix(requestedModel, resolved)
}
func isAPIKeyAuth(auth *Auth) bool {
if auth == nil {
return false
}
kind, _ := auth.AccountInfo()
return strings.EqualFold(strings.TrimSpace(kind), "api_key")
}
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
if !isAPIKeyAuth(auth) {
return false
}
if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
return true
}
if auth.Attributes == nil {
return false
}
return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
}
func openAICompatProviderKey(auth *Auth) string {
if auth == nil {
return ""
}
if auth.Attributes != nil {
if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" {
return strings.ToLower(providerKey)
}
if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" {
return strings.ToLower(compatName)
}
}
return strings.ToLower(strings.TrimSpace(auth.Provider))
}
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
if base == "" {
base = strings.TrimSpace(requestedModel)
}
return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base)
}
func (m *Manager) nextModelPoolOffset(key string, size int) int {
if m == nil || size <= 1 {
return 0
}
key = strings.TrimSpace(key)
if key == "" {
return 0
}
m.mu.Lock()
defer m.mu.Unlock()
if m.modelPoolOffsets == nil {
m.modelPoolOffsets = make(map[string]int)
}
offset := m.modelPoolOffsets[key]
if offset >= 2_147_483_640 {
offset = 0
}
m.modelPoolOffsets[key] = offset + 1
if size <= 0 {
return 0
}
return offset % size
}
func rotateStrings(values []string, offset int) []string {
if len(values) <= 1 {
return values
}
if offset <= 0 {
out := make([]string, len(values))
copy(out, values)
return out
}
offset = offset % len(values)
out := make([]string, 0, len(values))
out = append(out, values[offset:]...)
out = append(out, values[:offset]...)
return out
}
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
return nil
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil {
cfg = &internalconfig.Config{}
}
providerKey := ""
compatName := ""
if auth.Attributes != nil {
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
}
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
if entry == nil {
return nil
}
return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
}
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
return m.prepareExecutionModels(auth, routeModel)
}
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
requestedModel := rewriteModelForAuth(routeModel, auth)
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
if len(pool) == 1 {
return pool
}
offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool))
return rotateStrings(pool, offset)
}
resolved := m.applyAPIKeyModelAlias(auth, requestedModel)
if strings.TrimSpace(resolved) == "" {
resolved = requestedModel
}
return []string{resolved}
}
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
if ch == nil {
return
}
go func() {
for range ch {
}
}()
}
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
if ch == nil {
return nil, true, nil
}
buffered := make([]cliproxyexecutor.StreamChunk, 0, 1)
for {
var (
chunk cliproxyexecutor.StreamChunk
ok bool
)
if ctx != nil {
select {
case <-ctx.Done():
return nil, false, ctx.Err()
case chunk, ok = <-ch:
}
} else {
chunk, ok = <-ch
}
if !ok {
return buffered, true, nil
}
if chunk.Err != nil {
return nil, false, chunk.Err
}
buffered = append(buffered, chunk)
if len(chunk.Payload) > 0 {
return buffered, false, nil
}
}
}
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
var failed bool
forward := true
emit := func(chunk cliproxyexecutor.StreamChunk) bool {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
return false
}
if ctx == nil {
out <- chunk
return true
}
select {
case <-ctx.Done():
forward = false
return false
case out <- chunk:
return true
}
}
for _, chunk := range buffered {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
for chunk := range remaining {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
if !failed {
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
}
}()
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
execModels := m.prepareExecutionModels(auth, routeModel)
var lastErr error
for idx, execModel := range execModels {
execReq := req
execReq.Model = execModel
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
if errStream != nil {
if errCtx := ctx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(ctx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
if bootstrapErr != nil {
if errCtx := ctx.Err(); errCtx != nil {
discardStreamChunks(streamResult.Chunks)
return nil, errCtx
}
if isRequestInvalidError(bootstrapErr) {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
return nil, bootstrapErr
}
if idx < len(execModels)-1 {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
lastErr = bootstrapErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
if closed && len(buffered) == 0 {
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
m.MarkResult(ctx, result)
if idx < len(execModels)-1 {
lastErr = emptyErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
remaining := streamResult.Chunks
if closed {
closedCh := make(chan cliproxyexecutor.StreamChunk)
close(closedCh)
remaining = closedCh
}
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
}
if lastErr == nil {
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
}
return nil, lastErr
}
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
@@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
auth.ID = uuid.NewString()
}
auth.EnsureIndex()
authClone := auth.Clone()
m.mu.Lock()
m.auths[auth.ID] = auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil
@@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
}
}
auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone()
authClone := auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil
@@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
// Load resets manager state from the backing store.
func (m *Manager) Load(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.store == nil {
m.mu.Unlock()
return nil
}
items, err := m.store.List(ctx)
if err != nil {
m.mu.Unlock()
return err
}
m.auths = make(map[string]*Auth, len(items))
@@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error {
cfg = &internalconfig.Config{}
}
m.rebuildAPIKeyModelAliasLocked(cfg)
m.mu.Unlock()
m.syncScheduler()
return nil
}
@@ -634,32 +1005,42 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.MarkResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = errExec
lastErr = authErr
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
@@ -696,32 +1077,42 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.hook.OnResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.hook.OnResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = errExec
lastErr = authErr
continue
}
m.hook.OnResult(execCtx, result)
return resp, nil
}
}
@@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
return &cliproxyexecutor.StreamResult{
Headers: streamResult.Headers,
Chunks: out,
}, nil
return streamResult, nil
}
}
@@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
suspendReason := ""
clearModelQuota := false
setModelQuota := false
var authSnapshot *Auth
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
@@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
}
_ = m.persist(ctx, auth)
authSnapshot = auth.Clone()
}
m.mu.Unlock()
if m.scheduler != nil && authSnapshot != nil {
m.scheduler.upsertAuth(authSnapshot)
}
if clearModelQuota && result.Model != "" {
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
@@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int {
}
// isRequestInvalidError returns true if the error represents a client request
// error that should not be retried. Specifically, it checks for 400 Bad Request
// with "invalid_request_error" in the message, indicating the request itself is
// malformed and switching to a different auth will not help.
// error that should not be retried. Specifically, it treats 400 responses with
// "invalid_request_error" and all 422 responses as request-shape failures,
// where switching auths or pooled upstream models will not help.
func isRequestInvalidError(err error) bool {
if err == nil {
return false
}
status := statusCodeFromError(err)
if status != http.StatusBadRequest {
switch status {
case http.StatusBadRequest:
return strings.Contains(err.Error(), "invalid_request_error")
case http.StatusUnprocessableEntity:
return true
default:
return false
}
return strings.Contains(err.Error(), "invalid_request_error")
}
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
@@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
}
}
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
func (m *Manager) useSchedulerFastPath() bool {
if m == nil || m.scheduler == nil {
return false
}
return isBuiltInSelector(m.selector)
}
func shouldRetrySchedulerPick(err error) bool {
if err == nil {
return false
}
var cooldownErr *modelCooldownError
if errors.As(err, &cooldownErr) {
return true
}
var authErr *Error
if !errors.As(err, &authErr) || authErr == nil {
return false
}
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
}
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
@@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
return authCopy, executor, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
if !m.useSchedulerFastPath() {
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
executor, okExecutor := m.Executor(provider)
if !okExecutor {
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
}
if errPick != nil {
return nil, nil, errPick
}
if selected == nil {
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, nil
}
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers))
@@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
return authCopy, executor, providerKey, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
if !m.useSchedulerFastPath() {
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
}
eligibleProviders := make([]string, 0, len(providers))
seenProviders := make(map[string]struct{}, len(providers))
for _, provider := range providers {
providerKey := strings.TrimSpace(strings.ToLower(provider))
if providerKey == "" {
continue
}
if _, seen := seenProviders[providerKey]; seen {
continue
}
if _, okExecutor := m.Executor(providerKey); !okExecutor {
continue
}
seenProviders[providerKey] = struct{}{}
eligibleProviders = append(eligibleProviders, providerKey)
}
if len(eligibleProviders) == 0 {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
}
if errPick != nil {
return nil, nil, "", errPick
}
if selected == nil {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
executor, okExecutor := m.Executor(providerKey)
if !okExecutor {
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, providerKey, nil
}
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil {
return nil
@@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
current.LastError = &Error{Message: err.Error()}
m.auths[id] = current
if m.scheduler != nil {
m.scheduler.upsertAuth(current.Clone())
}
}
m.mu.Unlock()
return

View File

@@ -0,0 +1,163 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerProviderTestExecutor struct {
provider string
}
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
ctx := context.Background()
testCases := []struct {
name string
prime func(*Manager, *Auth) error
}{
{
name: "register",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
return errRegister
},
},
{
name: "update",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
if errRegister != nil {
return errRegister
}
updated := auth.Clone()
updated.Metadata = map[string]any{"updated": true}
_, errUpdate := manager.Update(ctx, updated)
return errUpdate
},
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
manager := NewManager(nil, &RoundRobinSelector{}, nil)
auth := &Auth{
ID: "refresh-entry-" + testCase.name,
Provider: "gemini",
}
if errPrime := testCase.prime(manager, auth); errPrime != nil {
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
}
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
var authErr *Error
if !errors.As(errPick, &authErr) || authErr == nil {
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
}
if authErr.Code != "auth_not_found" {
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
}
if got != nil {
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
}
manager.RefreshSchedulerEntry(auth.ID)
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() after refresh error = %v", errPick)
}
if got == nil || got.ID != auth.ID {
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
}
})
}
}
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
ctx := context.Background()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
oldAuth := &Auth{
ID: "cooldown-stale-old",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
t.Fatalf("register old auth: %v", errRegister)
}
manager.MarkResult(ctx, Result{
AuthID: oldAuth.ID,
Provider: "gemini",
Model: "scheduler-cooldown-rebuild-model",
Success: false,
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
})
newAuth := &Auth{
ID: "cooldown-stale-new",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
t.Fatalf("register new auth: %v", errRegister)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
t.Cleanup(func() {
reg.UnregisterClient(newAuth.ID)
})
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
var cooldownErr *modelCooldownError
if !errors.As(errPick, &cooldownErr) {
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
}
if got != nil {
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
}
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if executor == nil {
t.Fatal("pickNext() executor = nil")
}
if got == nil || got.ID != newAuth.ID {
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
}
}

View File

@@ -80,54 +80,98 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string
return upstreamModel
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
return thinking.SuffixResult{}, nil
}
if len(models) == 0 {
return ""
}
requestResult := thinking.ParseSuffix(requestedModel)
base := requestResult.ModelName
if base == "" {
base = requestedModel
}
candidates := []string{base}
if base != requestedModel {
candidates = append(candidates, requestedModel)
}
return requestResult, candidates
}
preserveSuffix := func(resolved string) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
}
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
if len(models) == 0 {
return nil
}
requestResult, candidates := modelAliasLookupCandidates(requestedModel)
if len(candidates) == 0 {
return nil
}
out := make([]string, 0)
seen := make(map[string]struct{})
for i := range models {
name := strings.TrimSpace(models[i].GetName())
alias := strings.TrimSpace(models[i].GetAlias())
for _, candidate := range candidates {
if candidate == "" {
if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) {
continue
}
if alias != "" && strings.EqualFold(alias, candidate) {
if name != "" {
return preserveSuffix(name)
}
return preserveSuffix(candidate)
resolved := candidate
if name != "" {
resolved = name
}
if name != "" && strings.EqualFold(name, candidate) {
return preserveSuffix(name)
resolved = preserveResolvedModelSuffix(resolved, requestResult)
key := strings.ToLower(strings.TrimSpace(resolved))
if key == "" {
break
}
if _, exists := seen[key]; exists {
break
}
seen[key] = struct{}{}
out = append(out, resolved)
break
}
}
if len(out) > 0 {
return out
}
for i := range models {
name := strings.TrimSpace(models[i].GetName())
for _, candidate := range candidates {
if candidate == "" || name == "" || !strings.EqualFold(name, candidate) {
continue
}
return []string{preserveResolvedModelSuffix(name, requestResult)}
}
}
return nil
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models)
if len(resolved) > 0 {
return resolved[0]
}
return ""
}

View File

@@ -0,0 +1,419 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type openAICompatPoolExecutor struct {
id string
mu sync.Mutex
executeModels []string
countModels []string
streamModels []string
executeErrors map[string]error
countErrors map[string]error
streamFirstErrors map[string]error
streamPayloads map[string][]cliproxyexecutor.StreamChunk
}
func (e *openAICompatPoolExecutor) Identifier() string { return e.id }
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
err := e.executeErrors[req.Model]
e.mu.Unlock()
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.streamModels = append(e.streamModels, req.Model)
err := e.streamFirstErrors[req.Model]
payloadChunks, hasCustomChunks := e.streamPayloads[req.Model]
chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...)
e.mu.Unlock()
ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks)))
if err != nil {
ch <- cliproxyexecutor.StreamChunk{Err: err}
close(ch)
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
}
if !hasCustomChunks {
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)}
} else {
for _, chunk := range chunks {
ch <- chunk
}
}
close(ch)
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
}
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.countModels = append(e.countModels, req.Model)
err := e.countErrors[req.Model]
e.mu.Unlock()
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
_ = ctx
_ = auth
_ = req
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.executeModels))
copy(out, e.executeModels)
return out
}
func (e *openAICompatPoolExecutor) CountModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.countModels))
copy(out, e.countModels)
return out
}
func (e *openAICompatPoolExecutor) StreamModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.streamModels))
copy(out, e.streamModels)
return out
}
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
t.Helper()
cfg := &internalconfig.Config{
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
Name: "pool",
Models: models,
}},
}
m := NewManager(nil, nil, nil)
m.SetConfig(cfg)
if executor == nil {
executor = &openAICompatPoolExecutor{id: "pool"}
}
m.RegisterExecutor(executor)
auth := &Auth{
ID: "pool-auth-" + t.Name(),
Provider: "pool",
Status: StatusActive,
Attributes: map[string]string{
"api_key": "test-key",
"compat_name": "pool",
"provider_key": "pool",
},
}
if _, err := m.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
t.Cleanup(func() {
reg.UnregisterClient(auth.ID)
})
return m
}
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
}
got := executor.CountModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("count calls = %v, want only first invalid model", got)
}
}
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
models := []modelAliasEntry{
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
}
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
if len(got) != len(want) {
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
for i := 0; i < 3; i++ {
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute %d: %v", i, err)
}
if len(resp.Payload) == 0 {
t.Fatalf("execute %d returned empty payload", i)
}
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute error = %v, want %v", err, invalidErr)
}
got := executor.ExecuteModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("execute calls = %v, want only first invalid model", got)
}
}
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute: %v", err)
}
if string(resp.Payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
"qwen3.5-plus": {},
},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute stream: %v", err)
}
var payload []byte
for chunk := range streamResult.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
payload = append(payload, chunk.Payload...)
}
if string(payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute stream: %v", err)
}
var payload []byte
for chunk := range streamResult.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
payload = append(payload, chunk.Payload...)
}
if string(payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
}
}
if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
}
got := executor.StreamModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("stream calls = %v, want only first invalid model", got)
}
}
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
for i := 0; i < 2; i++ {
resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute count %d: %v", i, err)
}
if len(resp.Payload) == 0 {
t.Fatalf("execute count %d returned empty payload", i)
}
}
got := executor.CountModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil {
t.Fatal("expected invalid request error")
}
if err != invalidErr {
t.Fatalf("error = %v, want %v", err, invalidErr)
}
if streamResult != nil {
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
}
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("stream calls = %v, want only first upstream model", got)
}
}

View File

@@ -0,0 +1,904 @@
package auth
import (
"context"
"sort"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
type schedulerStrategy int
const (
schedulerStrategyCustom schedulerStrategy = iota
schedulerStrategyRoundRobin
schedulerStrategyFillFirst
)
// scheduledState describes how an auth currently participates in a model shard.
type scheduledState int
const (
scheduledStateReady scheduledState = iota
scheduledStateCooldown
scheduledStateBlocked
scheduledStateDisabled
)
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
type authScheduler struct {
mu sync.Mutex
strategy schedulerStrategy
providers map[string]*providerScheduler
authProviders map[string]string
mixedCursors map[string]int
}
// providerScheduler stores auth metadata and model shards for a single provider.
type providerScheduler struct {
providerKey string
auths map[string]*scheduledAuthMeta
modelShards map[string]*modelScheduler
}
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
type scheduledAuthMeta struct {
auth *Auth
providerKey string
priority int
virtualParent string
websocketEnabled bool
supportedModelSet map[string]struct{}
}
// modelScheduler tracks ready and blocked auths for one provider/model combination.
type modelScheduler struct {
modelKey string
entries map[string]*scheduledAuth
priorityOrder []int
readyByPriority map[int]*readyBucket
blocked cooldownQueue
}
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
type scheduledAuth struct {
meta *scheduledAuthMeta
auth *Auth
state scheduledState
nextRetryAt time.Time
}
// readyBucket keeps the ready views for one priority level.
type readyBucket struct {
all readyView
ws readyView
}
// readyView holds the selection order for flat or grouped round-robin traversal.
type readyView struct {
flat []*scheduledAuth
cursor int
parentOrder []string
parentCursor int
children map[string]*childBucket
}
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
type childBucket struct {
items []*scheduledAuth
cursor int
}
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
type cooldownQueue []*scheduledAuth
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
func newAuthScheduler(selector Selector) *authScheduler {
return &authScheduler{
strategy: selectorStrategy(selector),
providers: make(map[string]*providerScheduler),
authProviders: make(map[string]string),
mixedCursors: make(map[string]int),
}
}
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
func selectorStrategy(selector Selector) schedulerStrategy {
switch selector.(type) {
case *FillFirstSelector:
return schedulerStrategyFillFirst
case nil, *RoundRobinSelector:
return schedulerStrategyRoundRobin
default:
return schedulerStrategyCustom
}
}
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
func (s *authScheduler) setSelector(selector Selector) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.strategy = selectorStrategy(selector)
clear(s.mixedCursors)
}
// rebuild recreates the complete scheduler state from an auth snapshot.
func (s *authScheduler) rebuild(auths []*Auth) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.providers = make(map[string]*providerScheduler)
s.authProviders = make(map[string]string)
s.mixedCursors = make(map[string]int)
now := time.Now()
for _, auth := range auths {
s.upsertAuthLocked(auth, now)
}
}
// upsertAuth incrementally synchronizes one auth into the scheduler.
func (s *authScheduler) upsertAuth(auth *Auth) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.upsertAuthLocked(auth, time.Now())
}
// removeAuth deletes one auth from every scheduler shard that references it.
func (s *authScheduler) removeAuth(authID string) {
if s == nil {
return
}
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.removeAuthLocked(authID)
}
// pickSingle returns the next auth for a single provider/model request using scheduler state.
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
if s == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
providerKey := strings.ToLower(strings.TrimSpace(provider))
modelKey := canonicalModelKey(model)
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
s.mu.Lock()
defer s.mu.Unlock()
providerState := s.providers[providerKey]
if providerState == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
shard := providerState.ensureModelLocked(modelKey, time.Now())
if shard == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
predicate := func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil {
return false
}
if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
return false
}
if len(tried) > 0 {
if _, ok := tried[entry.auth.ID]; ok {
return false
}
}
return true
}
if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
return picked, nil
}
return nil, shard.unavailableErrorLocked(provider, model, predicate)
}
// pickMixed returns the next auth and provider for a mixed-provider request.
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
if s == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
normalized := normalizeProviderKeys(providers)
if len(normalized) == 0 {
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
modelKey := canonicalModelKey(model)
s.mu.Lock()
defer s.mu.Unlock()
if pinnedAuthID != "" {
providerKey := s.authProviders[pinnedAuthID]
if providerKey == "" || !containsProvider(normalized, providerKey) {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
providerState := s.providers[providerKey]
if providerState == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
shard := providerState.ensureModelLocked(modelKey, time.Now())
predicate := func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
return false
}
if len(tried) == 0 {
return true
}
_, ok := tried[pinnedAuthID]
return !ok
}
if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
return picked, providerKey, nil
}
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
}
predicate := triedPredicate(tried)
candidateShards := make([]*modelScheduler, len(normalized))
bestPriority := 0
hasCandidate := false
now := time.Now()
for providerIndex, providerKey := range normalized {
providerState := s.providers[providerKey]
if providerState == nil {
continue
}
shard := providerState.ensureModelLocked(modelKey, now)
candidateShards[providerIndex] = shard
if shard == nil {
continue
}
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
if !okPriority {
continue
}
if !hasCandidate || priorityReady > bestPriority {
bestPriority = priorityReady
hasCandidate = true
}
}
if !hasCandidate {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
if s.strategy == schedulerStrategyFillFirst {
for providerIndex, providerKey := range normalized {
shard := candidateShards[providerIndex]
if shard == nil {
continue
}
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
if picked != nil {
return picked, providerKey, nil
}
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
start := 0
if len(normalized) > 0 {
start = s.mixedCursors[cursorKey] % len(normalized)
}
for offset := 0; offset < len(normalized); offset++ {
providerIndex := (start + offset) % len(normalized)
providerKey := normalized[providerIndex]
shard := candidateShards[providerIndex]
if shard == nil {
continue
}
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
if picked == nil {
continue
}
s.mixedCursors[cursorKey] = providerIndex + 1
return picked, providerKey, nil
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
now := time.Now()
total := 0
cooldownCount := 0
earliest := time.Time{}
for _, providerKey := range providers {
providerState := s.providers[providerKey]
if providerState == nil {
continue
}
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
if shard == nil {
continue
}
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
total += localTotal
cooldownCount += localCooldownCount
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
earliest = localEarliest
}
}
if total == 0 {
return &Error{Code: "auth_not_found", Message: "no auth available"}
}
if cooldownCount == total && !earliest.IsZero() {
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return newModelCooldownError(model, "", resetIn)
}
return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// triedPredicate builds a filter that excludes auths already attempted for the current request.
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
if len(tried) == 0 {
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
}
return func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil {
return false
}
_, ok := tried[entry.auth.ID]
return !ok
}
}
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
func normalizeProviderKeys(providers []string) []string {
seen := make(map[string]struct{}, len(providers))
out := make([]string, 0, len(providers))
for _, provider := range providers {
providerKey := strings.ToLower(strings.TrimSpace(provider))
if providerKey == "" {
continue
}
if _, ok := seen[providerKey]; ok {
continue
}
seen[providerKey] = struct{}{}
out = append(out, providerKey)
}
return out
}
// containsProvider reports whether provider is present in the normalized provider list.
func containsProvider(providers []string, provider string) bool {
for _, candidate := range providers {
if candidate == provider {
return true
}
}
return false
}
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
if auth == nil {
return
}
authID := strings.TrimSpace(auth.ID)
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
if authID == "" || providerKey == "" || auth.Disabled {
s.removeAuthLocked(authID)
return
}
if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
if previousState := s.providers[previousProvider]; previousState != nil {
previousState.removeAuthLocked(authID)
}
}
meta := buildScheduledAuthMeta(auth)
s.authProviders[authID] = providerKey
s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
}
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
func (s *authScheduler) removeAuthLocked(authID string) {
if authID == "" {
return
}
if providerKey := s.authProviders[authID]; providerKey != "" {
if providerState := s.providers[providerKey]; providerState != nil {
providerState.removeAuthLocked(authID)
}
delete(s.authProviders, authID)
}
}
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
if s.providers == nil {
s.providers = make(map[string]*providerScheduler)
}
providerState := s.providers[providerKey]
if providerState == nil {
providerState = &providerScheduler{
providerKey: providerKey,
auths: make(map[string]*scheduledAuthMeta),
modelShards: make(map[string]*modelScheduler),
}
s.providers[providerKey] = providerState
}
return providerState
}
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping.
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
virtualParent := ""
if auth.Attributes != nil {
virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
}
return &scheduledAuthMeta{
auth: auth,
providerKey: providerKey,
priority: authPriority(auth),
virtualParent: virtualParent,
websocketEnabled: authWebsocketsEnabled(auth),
supportedModelSet: supportedModelSetForAuth(auth.ID),
}
}
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
func supportedModelSetForAuth(authID string) map[string]struct{} {
authID = strings.TrimSpace(authID)
if authID == "" {
return nil
}
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
if len(models) == 0 {
return nil
}
set := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
modelKey := canonicalModelKey(model.ID)
if modelKey == "" {
continue
}
set[modelKey] = struct{}{}
}
return set
}
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
if p == nil || meta == nil || meta.auth == nil {
return
}
p.auths[meta.auth.ID] = meta
for modelKey, shard := range p.modelShards {
if shard == nil {
continue
}
if !meta.supportsModel(modelKey) {
shard.removeEntryLocked(meta.auth.ID)
continue
}
shard.upsertEntryLocked(meta, now)
}
}
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
func (p *providerScheduler) removeAuthLocked(authID string) {
if p == nil || authID == "" {
return
}
delete(p.auths, authID)
for _, shard := range p.modelShards {
if shard != nil {
shard.removeEntryLocked(authID)
}
}
}
// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths.
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
if p == nil {
return nil
}
modelKey = canonicalModelKey(modelKey)
if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
shard.promoteExpiredLocked(now)
return shard
}
shard := &modelScheduler{
modelKey: modelKey,
entries: make(map[string]*scheduledAuth),
readyByPriority: make(map[int]*readyBucket),
}
for _, meta := range p.auths {
if meta == nil || !meta.supportsModel(modelKey) {
continue
}
shard.upsertEntryLocked(meta, now)
}
p.modelShards[modelKey] = shard
return shard
}
// supportsModel reports whether the auth metadata currently supports modelKey.
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
modelKey = canonicalModelKey(modelKey)
if modelKey == "" {
return true
}
if len(m.supportedModelSet) == 0 {
return false
}
_, ok := m.supportedModelSet[modelKey]
return ok
}
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
if m == nil || meta == nil || meta.auth == nil {
return
}
entry, ok := m.entries[meta.auth.ID]
if !ok || entry == nil {
entry = &scheduledAuth{}
m.entries[meta.auth.ID] = entry
}
previousState := entry.state
previousNextRetryAt := entry.nextRetryAt
previousPriority := 0
previousParent := ""
previousWebsocketEnabled := false
if entry.meta != nil {
previousPriority = entry.meta.priority
previousParent = entry.meta.virtualParent
previousWebsocketEnabled = entry.meta.websocketEnabled
}
entry.meta = meta
entry.auth = meta.auth
entry.nextRetryAt = time.Time{}
blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
switch {
case !blocked:
entry.state = scheduledStateReady
case reason == blockReasonCooldown:
entry.state = scheduledStateCooldown
entry.nextRetryAt = next
case reason == blockReasonDisabled:
entry.state = scheduledStateDisabled
default:
entry.state = scheduledStateBlocked
entry.nextRetryAt = next
}
if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
return
}
m.rebuildIndexesLocked()
}
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
func (m *modelScheduler) removeEntryLocked(authID string) {
if m == nil || authID == "" {
return
}
if _, ok := m.entries[authID]; !ok {
return
}
delete(m.entries, authID)
m.rebuildIndexesLocked()
}
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
if m == nil || len(m.blocked) == 0 {
return
}
changed := false
for _, entry := range m.blocked {
if entry == nil || entry.auth == nil {
continue
}
if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
continue
}
blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
switch {
case !blocked:
entry.state = scheduledStateReady
entry.nextRetryAt = time.Time{}
case reason == blockReasonCooldown:
entry.state = scheduledStateCooldown
entry.nextRetryAt = next
case reason == blockReasonDisabled:
entry.state = scheduledStateDisabled
entry.nextRetryAt = time.Time{}
default:
entry.state = scheduledStateBlocked
entry.nextRetryAt = next
}
changed = true
}
if changed {
m.rebuildIndexesLocked()
}
}
// pickReadyLocked selects the next ready auth from the highest available priority bucket.
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
m.promoteExpiredLocked(time.Now())
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
if !okPriority {
return nil
}
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
}
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
if m == nil {
return 0, false
}
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
if view.pickFirst(predicate) != nil {
return priority, true
}
}
return 0, false
}
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return nil
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
var picked *scheduledAuth
if strategy == schedulerStrategyFillFirst {
picked = view.pickFirst(predicate)
} else {
picked = view.pickRoundRobin(predicate)
}
if picked == nil || picked.auth == nil {
return nil
}
return picked.auth
}
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
now := time.Now()
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
if total == 0 {
return &Error{Code: "auth_not_found", Message: "no auth available"}
}
if cooldownCount == total && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return newModelCooldownError(model, providerForError, resetIn)
}
return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
if m == nil {
return 0, 0, time.Time{}
}
total := 0
cooldownCount := 0
earliest := time.Time{}
for _, entry := range m.entries {
if predicate != nil && !predicate(entry) {
continue
}
total++
if entry == nil || entry.auth == nil {
continue
}
if entry.state != scheduledStateCooldown {
continue
}
cooldownCount++
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
earliest = entry.nextRetryAt
}
}
return total, cooldownCount, earliest
}
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
func (m *modelScheduler) rebuildIndexesLocked() {
m.readyByPriority = make(map[int]*readyBucket)
m.priorityOrder = m.priorityOrder[:0]
m.blocked = m.blocked[:0]
priorityBuckets := make(map[int][]*scheduledAuth)
for _, entry := range m.entries {
if entry == nil || entry.auth == nil {
continue
}
switch entry.state {
case scheduledStateReady:
priority := entry.meta.priority
priorityBuckets[priority] = append(priorityBuckets[priority], entry)
case scheduledStateCooldown, scheduledStateBlocked:
m.blocked = append(m.blocked, entry)
}
}
for priority, entries := range priorityBuckets {
sort.Slice(entries, func(i, j int) bool {
return entries[i].auth.ID < entries[j].auth.ID
})
m.readyByPriority[priority] = buildReadyBucket(entries)
m.priorityOrder = append(m.priorityOrder, priority)
}
sort.Slice(m.priorityOrder, func(i, j int) bool {
return m.priorityOrder[i] > m.priorityOrder[j]
})
sort.Slice(m.blocked, func(i, j int) bool {
left := m.blocked[i]
right := m.blocked[j]
if left == nil || right == nil {
return left != nil
}
if left.nextRetryAt.Equal(right.nextRetryAt) {
return left.auth.ID < right.auth.ID
}
if left.nextRetryAt.IsZero() {
return false
}
if right.nextRetryAt.IsZero() {
return true
}
return left.nextRetryAt.Before(right.nextRetryAt)
})
}
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
bucket := &readyBucket{}
bucket.all = buildReadyView(entries)
wsEntries := make([]*scheduledAuth, 0, len(entries))
for _, entry := range entries {
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
wsEntries = append(wsEntries, entry)
}
}
bucket.ws = buildReadyView(wsEntries)
return bucket
}
// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
func buildReadyView(entries []*scheduledAuth) readyView {
view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
if len(entries) == 0 {
return view
}
groups := make(map[string][]*scheduledAuth)
for _, entry := range entries {
if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
return view
}
groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
}
if len(groups) <= 1 {
return view
}
view.children = make(map[string]*childBucket, len(groups))
view.parentOrder = make([]string, 0, len(groups))
for parent := range groups {
view.parentOrder = append(view.parentOrder, parent)
}
sort.Strings(view.parentOrder)
for _, parent := range view.parentOrder {
view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
}
return view
}
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
for _, entry := range v.flat {
if predicate == nil || predicate(entry) {
return entry
}
}
return nil
}
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
if len(v.parentOrder) > 1 && len(v.children) > 0 {
return v.pickGroupedRoundRobin(predicate)
}
if len(v.flat) == 0 {
return nil
}
start := 0
if len(v.flat) > 0 {
start = v.cursor % len(v.flat)
}
for offset := 0; offset < len(v.flat); offset++ {
index := (start + offset) % len(v.flat)
entry := v.flat[index]
if predicate != nil && !predicate(entry) {
continue
}
v.cursor = index + 1
return entry
}
return nil
}
// pickGroupedRoundRobin rotates across parents first and then within the selected parent.
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
start := 0
if len(v.parentOrder) > 0 {
start = v.parentCursor % len(v.parentOrder)
}
for offset := 0; offset < len(v.parentOrder); offset++ {
parentIndex := (start + offset) % len(v.parentOrder)
parent := v.parentOrder[parentIndex]
child := v.children[parent]
if child == nil || len(child.items) == 0 {
continue
}
itemStart := child.cursor % len(child.items)
for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
itemIndex := (itemStart + itemOffset) % len(child.items)
entry := child.items[itemIndex]
if predicate != nil && !predicate(entry) {
continue
}
child.cursor = itemIndex + 1
v.parentCursor = parentIndex + 1
return entry
}
}
return nil
}

View File

@@ -0,0 +1,216 @@
package auth
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerBenchmarkExecutor struct {
id string
}
func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }
func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
b.Helper()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
providers := []string{"gemini"}
manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
if mixed {
providers = []string{"gemini", "claude"}
manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
}
reg := registry.GetGlobalRegistry()
model := "bench-model"
for index := 0; index < total; index++ {
provider := providers[0]
if mixed && index%2 == 1 {
provider = providers[1]
}
auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
if withPriority {
priority := "0"
if index%2 == 0 {
priority = "10"
}
auth.Attributes = map[string]string{"priority": priority}
}
_, errRegister := manager.Register(context.Background(), auth)
if errRegister != nil {
b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
}
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
}
manager.syncScheduler()
b.Cleanup(func() {
for index := 0; index < total; index++ {
provider := providers[0]
if mixed && index%2 == 1 {
provider = providers[1]
}
reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
}
})
return manager, providers, model
}
func BenchmarkManagerPickNext500(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 500, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNext1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextPriority500(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 500, false, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextMixed500(b *testing.B) {
manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
if errPick != nil || auth == nil || exec == nil || provider == "" {
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
}
}
}
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
if errPick != nil || auth == nil || exec == nil || provider == "" {
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
}
}
}
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil {
b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
}
manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
}
}

View File

@@ -0,0 +1,503 @@
package auth
import (
"context"
"net/http"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerTestExecutor struct{}
func (schedulerTestExecutor) Identifier() string { return "test" }
func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
type trackingSelector struct {
calls int
lastAuthID []string
}
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
s.calls++
s.lastAuthID = s.lastAuthID[:0]
for _, auth := range auths {
s.lastAuthID = append(s.lastAuthID, auth.ID)
}
if len(auths) == 0 {
return nil, nil
}
return auths[len(auths)-1], nil
}
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
scheduler := newAuthScheduler(selector)
scheduler.rebuild(auths)
return scheduler
}
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
t.Helper()
reg := registry.GetGlobalRegistry()
for _, authID := range authIDs {
reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
}
t.Cleanup(func() {
for _, authID := range authIDs {
reg.UnregisterClient(authID)
}
})
}
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
)
want := []string{"high-a", "high-b", "high-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&FillFirstSelector{},
&Auth{ID: "b", Provider: "gemini"},
&Auth{ID: "a", Provider: "gemini"},
&Auth{ID: "c", Provider: "gemini"},
)
for index := 0; index < 3; index++ {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != "a" {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
}
}
}
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
t.Parallel()
model := "gemini-2.5-pro"
registerSchedulerModels(t, "gemini", model, "cooldown-expired")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{
ID: "cooldown-expired",
Provider: "gemini",
ModelStates: map[string]*ModelState{
model: {
Status: StatusError,
Unavailable: true,
NextRetryAfter: time.Now().Add(-1 * time.Second),
},
},
},
)
got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickSingle() auth = nil")
}
if got.ID != "cooldown-expired" {
t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
}
}
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
t.Parallel()
registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
)
wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
for index := range wantIDs {
got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantIDs[index] {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
}
}
}
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "codex-http", Provider: "codex"},
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
)
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "gemini-a", Provider: "gemini"},
&Auth{ID: "gemini-b", Provider: "gemini"},
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
t.Parallel()
model := "gpt-default"
registerSchedulerModels(t, "provider-low", model, "low")
registerSchedulerModels(t, "provider-high-a", model, "high-a")
registerSchedulerModels(t, "provider-high-b", model, "high-b")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
)
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
t.Parallel()
selector := &trackingSelector{}
manager := NewManager(nil, selector, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"}
manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"}
got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickNext() auth = nil")
}
if selector.calls != 1 {
t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
}
if len(selector.lastAuthID) != 2 {
t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2)
}
if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] {
t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1])
}
}
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
if manager.scheduler == nil {
t.Fatalf("manager.scheduler = nil")
}
if manager.scheduler.strategy != schedulerStrategyRoundRobin {
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
}
manager.SetSelector(&FillFirstSelector{})
if manager.scheduler.strategy != schedulerStrategyFillFirst {
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
}
}
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-a) error = %v", errRegister)
}
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() error = %v", errPick)
}
if got == nil || got.ID != "auth-a" {
t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got)
}
if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil {
t.Fatalf("Update(auth-a) error = %v", errUpdate)
}
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
}
if got == nil || got.ID != "auth-b" {
t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got)
}
}
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNextMixed() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() auth = nil")
}
if provider != "claude" {
t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
}
if got.ID != "claude-a" {
t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a")
}
}
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
reg.UnregisterClient("auth-a")
reg.UnregisterClient("auth-b")
})
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-b) error = %v", errRegister)
}
manager.MarkResult(context.Background(), Result{
AuthID: "auth-a",
Provider: "gemini",
Model: "test-model",
Success: false,
Error: &Error{HTTPStatus: 429, Message: "quota"},
})
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
}
if got == nil || got.ID != "auth-b" {
t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got)
}
manager.MarkResult(context.Background(), Result{
AuthID: "auth-a",
Provider: "gemini",
Model: "test-model",
Success: true,
})
seen := make(map[string]struct{}, 2)
for index := 0; index < 2; index++ {
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index)
}
seen[got.ID] = struct{}{}
}
if len(seen) != 2 {
t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
}
}

View File

@@ -323,6 +323,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// This operation may block on network calls, but the auth configuration
// is already effective at this point.
s.registerModelsForAuth(auth)
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
// from the now-populated global model registry. Without this, newly added auths
// have an empty supportedModelSet (because Register/Update upserts into the
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
s.coreManager.RefreshSchedulerEntry(auth.ID)
}
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {