mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-26 05:36:12 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
046865461e | ||
|
|
cf74ed2f0c | ||
|
|
e333fbea3d | ||
|
|
efbe36d1d4 | ||
|
|
8553cfa40e | ||
|
|
30d5c95b26 | ||
|
|
d1e3195e6f | ||
|
|
05a35662ae | ||
|
|
ce53d3a287 | ||
|
|
4cc99e7449 | ||
|
|
71773fe032 | ||
|
|
a1e0fa0f39 | ||
|
|
fc2f0b6983 | ||
|
|
5c9997cdac | ||
|
|
6f81046730 | ||
|
|
0687472d01 | ||
|
|
7739738fb3 | ||
|
|
99d1ce247b | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 | ||
|
|
91a2b1f0b4 |
4
.github/workflows/docker-image.yml
vendored
4
.github/workflows/docker-image.yml
vendored
@@ -16,6 +16,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -47,6 +49,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
|
|||||||
2
.github/workflows/pr-test-build.yml
vendored
2
.github/workflows/pr-test-build.yml
vendored
@@ -12,6 +12,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -16,6 +16,8 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- run: git fetch --force --tags
|
- run: git fetch --force --tags
|
||||||
- uses: actions/setup-go@v4
|
- uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
117
README.md
117
README.md
@@ -8,123 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
|
|||||||
|
|
||||||
The Plus release stays in lockstep with the mainline features.
|
The Plus release stays in lockstep with the mainline features.
|
||||||
|
|
||||||
## Differences from the Mainline
|
|
||||||
|
|
||||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
|
||||||
|
|
||||||
## New Features (Plus Enhanced)
|
|
||||||
|
|
||||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 model (GLM-5 is available to Pro users only) across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
|
||||||
|
|
||||||
## Kiro Authentication
|
|
||||||
|
|
||||||
### CLI Login
|
|
||||||
|
|
||||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
|
||||||
|
|
||||||
**AWS Builder ID** (recommended):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Device code flow
|
|
||||||
./CLIProxyAPI --kiro-aws-login
|
|
||||||
|
|
||||||
# Authorization code flow
|
|
||||||
./CLIProxyAPI --kiro-aws-authcode
|
|
||||||
```
|
|
||||||
|
|
||||||
**Import token from Kiro IDE:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-import
|
|
||||||
```
|
|
||||||
|
|
||||||
To get a token from Kiro IDE:
|
|
||||||
|
|
||||||
1. Open Kiro IDE and login with Google (or GitHub)
|
|
||||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
|
||||||
3. Run: `./CLIProxyAPI --kiro-import`
|
|
||||||
|
|
||||||
**AWS IAM Identity Center (IDC):**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
|
||||||
|
|
||||||
# Specify region
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
|
||||||
```
|
|
||||||
|
|
||||||
**Additional flags:**
|
|
||||||
|
|
||||||
| Flag | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
|
||||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
|
||||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
|
||||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
|
||||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
|
||||||
|
|
||||||
### Web-based OAuth Login
|
|
||||||
|
|
||||||
Access the Kiro OAuth web interface at:
|
|
||||||
|
|
||||||
```
|
|
||||||
http://your-server:8080/v0/oauth/kiro
|
|
||||||
```
|
|
||||||
|
|
||||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
|
||||||
- AWS Builder ID login
|
|
||||||
- AWS Identity Center (IDC) login
|
|
||||||
- Token import from Kiro IDE
|
|
||||||
|
|
||||||
## Quick Deployment with Docker
|
|
||||||
|
|
||||||
### One-Command Deployment
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Create deployment directory
|
|
||||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
|
||||||
|
|
||||||
# Create docker-compose.yml
|
|
||||||
cat > docker-compose.yml << 'EOF'
|
|
||||||
services:
|
|
||||||
cli-proxy-api:
|
|
||||||
image: eceasy/cli-proxy-api-plus:latest
|
|
||||||
container_name: cli-proxy-api-plus
|
|
||||||
ports:
|
|
||||||
- "8317:8317"
|
|
||||||
volumes:
|
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
|
||||||
- ./auths:/root/.cli-proxy-api
|
|
||||||
- ./logs:/CLIProxyAPI/logs
|
|
||||||
restart: unless-stopped
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# Download example config
|
|
||||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
|
||||||
|
|
||||||
# Pull and start
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
|
|
||||||
Edit `config.yaml` before starting:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Basic configuration example
|
|
||||||
server:
|
|
||||||
port: 8317
|
|
||||||
|
|
||||||
# Add your provider configurations here
|
|
||||||
```
|
|
||||||
|
|
||||||
### Update to Latest Version
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/cli-proxy
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||||
|
|||||||
117
README_CN.md
117
README_CN.md
@@ -8,123 +8,6 @@
|
|||||||
|
|
||||||
该 Plus 版本的主线功能与主线功能强制同步。
|
该 Plus 版本的主线功能与主线功能强制同步。
|
||||||
|
|
||||||
## 与主线版本的差异
|
|
||||||
|
|
||||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
|
||||||
|
|
||||||
## 新增功能 (Plus 增强版)
|
|
||||||
|
|
||||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
|
||||||
|
|
||||||
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
|
||||||
|
|
||||||
### 命令行登录
|
|
||||||
|
|
||||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
|
||||||
|
|
||||||
**AWS Builder ID**(推荐):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 设备码流程
|
|
||||||
./CLIProxyAPI --kiro-aws-login
|
|
||||||
|
|
||||||
# 授权码流程
|
|
||||||
./CLIProxyAPI --kiro-aws-authcode
|
|
||||||
```
|
|
||||||
|
|
||||||
**从 Kiro IDE 导入令牌:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-import
|
|
||||||
```
|
|
||||||
|
|
||||||
获取令牌步骤:
|
|
||||||
|
|
||||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
|
||||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
|
||||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
|
||||||
|
|
||||||
**AWS IAM Identity Center (IDC):**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
|
||||||
|
|
||||||
# 指定区域
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
|
||||||
```
|
|
||||||
|
|
||||||
**附加参数:**
|
|
||||||
|
|
||||||
| 参数 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
|
||||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
|
||||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
|
||||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
|
||||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
|
||||||
|
|
||||||
### 网页端 OAuth 登录
|
|
||||||
|
|
||||||
访问 Kiro OAuth 网页认证界面:
|
|
||||||
|
|
||||||
```
|
|
||||||
http://your-server:8080/v0/oauth/kiro
|
|
||||||
```
|
|
||||||
|
|
||||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
|
||||||
- AWS Builder ID 登录
|
|
||||||
- AWS Identity Center (IDC) 登录
|
|
||||||
- 从 Kiro IDE 导入令牌
|
|
||||||
|
|
||||||
## Docker 快速部署
|
|
||||||
|
|
||||||
### 一键部署
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 创建部署目录
|
|
||||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
|
||||||
|
|
||||||
# 创建 docker-compose.yml
|
|
||||||
cat > docker-compose.yml << 'EOF'
|
|
||||||
services:
|
|
||||||
cli-proxy-api:
|
|
||||||
image: eceasy/cli-proxy-api-plus:latest
|
|
||||||
container_name: cli-proxy-api-plus
|
|
||||||
ports:
|
|
||||||
- "8317:8317"
|
|
||||||
volumes:
|
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
|
||||||
- ./auths:/root/.cli-proxy-api
|
|
||||||
- ./logs:/CLIProxyAPI/logs
|
|
||||||
restart: unless-stopped
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# 下载示例配置
|
|
||||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
|
||||||
|
|
||||||
# 拉取并启动
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### 配置说明
|
|
||||||
|
|
||||||
启动前请编辑 `config.yaml`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# 基本配置示例
|
|
||||||
server:
|
|
||||||
port: 8317
|
|
||||||
|
|
||||||
# 在此添加你的供应商配置
|
|
||||||
```
|
|
||||||
|
|
||||||
### 更新到最新版本
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/cli-proxy
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||||
@@ -573,6 +574,7 @@ func main() {
|
|||||||
if standalone {
|
if standalone {
|
||||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
hook := tui.NewLogHook(2000)
|
hook := tui.NewLogHook(2000)
|
||||||
hook.SetFormatter(&logging.LogFormatter{})
|
hook.SetFormatter(&logging.LogFormatter{})
|
||||||
log.AddHook(hook)
|
log.AddHook(hook)
|
||||||
@@ -643,15 +645,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
|
||||||
if cfg.AuthDir != "" {
|
if cfg.AuthDir != "" {
|
||||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||||
defer kiro.StopGlobalRefreshManager()
|
defer kiro.StopGlobalRefreshManager()
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||||
// Static model metadata is stored in model_definitions_static_data.go.
|
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,6 +7,131 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AntigravityModelConfig captures static antigravity model overrides, including
|
||||||
|
// Thinking budget limits and provider max completion tokens.
|
||||||
|
type AntigravityModelConfig struct {
|
||||||
|
Thinking *ThinkingSupport `json:"thinking,omitempty"`
|
||||||
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||||
|
type staticModelsJSON struct {
|
||||||
|
Claude []*ModelInfo `json:"claude"`
|
||||||
|
Gemini []*ModelInfo `json:"gemini"`
|
||||||
|
Vertex []*ModelInfo `json:"vertex"`
|
||||||
|
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||||
|
AIStudio []*ModelInfo `json:"aistudio"`
|
||||||
|
CodexFree []*ModelInfo `json:"codex-free"`
|
||||||
|
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||||
|
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||||
|
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||||
|
Qwen []*ModelInfo `json:"qwen"`
|
||||||
|
IFlow []*ModelInfo `json:"iflow"`
|
||||||
|
Kimi []*ModelInfo `json:"kimi"`
|
||||||
|
Antigravity map[string]*AntigravityModelConfig `json:"antigravity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaudeModels returns the standard Claude model definitions.
|
||||||
|
func GetClaudeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Claude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiModels returns the standard Gemini model definitions.
|
||||||
|
func GetGeminiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||||
|
func GetGeminiVertexModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Vertex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||||
|
func GetGeminiCLIModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().GeminiCLI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIStudioModels returns model definitions for AI Studio.
|
||||||
|
func GetAIStudioModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().AIStudio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||||
|
func GetCodexFreeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexFree)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||||
|
func GetCodexTeamModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexTeam)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||||
|
func GetCodexPlusModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPlus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||||
|
func GetCodexProModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPro)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQwenModels returns the standard Qwen model definitions.
|
||||||
|
func GetQwenModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Qwen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIFlowModels returns the standard iFlow model definitions.
|
||||||
|
func GetIFlowModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().IFlow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||||
|
func GetKimiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Kimi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||||
|
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||||
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
|
data := getModels()
|
||||||
|
if len(data.Antigravity) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]*AntigravityModelConfig, len(data.Antigravity))
|
||||||
|
for k, v := range data.Antigravity {
|
||||||
|
out[k] = cloneAntigravityModelConfig(v)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copyConfig := *cfg
|
||||||
|
if cfg.Thinking != nil {
|
||||||
|
copyThinking := *cfg.Thinking
|
||||||
|
if len(cfg.Thinking.Levels) > 0 {
|
||||||
|
copyThinking.Levels = append([]string(nil), cfg.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
copyConfig.Thinking = ©Thinking
|
||||||
|
}
|
||||||
|
return ©Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||||
|
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*ModelInfo, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
out[i] = cloneModelInfo(m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||||
// It returns nil when the channel is unknown.
|
// It returns nil when the channel is unknown.
|
||||||
//
|
//
|
||||||
@@ -39,7 +164,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "aistudio":
|
case "aistudio":
|
||||||
return GetAIStudioModels()
|
return GetAIStudioModels()
|
||||||
case "codex":
|
case "codex":
|
||||||
return GetOpenAIModels()
|
return GetCodexProModels()
|
||||||
case "qwen":
|
case "qwen":
|
||||||
return GetQwenModels()
|
return GetQwenModels()
|
||||||
case "iflow":
|
case "iflow":
|
||||||
@@ -89,16 +214,17 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data := getModels()
|
||||||
allModels := [][]*ModelInfo{
|
allModels := [][]*ModelInfo{
|
||||||
GetClaudeModels(),
|
data.Claude,
|
||||||
GetGeminiModels(),
|
data.Gemini,
|
||||||
GetGeminiVertexModels(),
|
data.Vertex,
|
||||||
GetGeminiCLIModels(),
|
data.GeminiCLI,
|
||||||
GetAIStudioModels(),
|
data.AIStudio,
|
||||||
GetOpenAIModels(),
|
data.CodexPro,
|
||||||
GetQwenModels(),
|
data.Qwen,
|
||||||
GetIFlowModels(),
|
data.IFlow,
|
||||||
GetKimiModels(),
|
data.Kimi,
|
||||||
GetGitHubCopilotModels(),
|
GetGitHubCopilotModels(),
|
||||||
GetKiroModels(),
|
GetKiroModels(),
|
||||||
GetKiloModels(),
|
GetKiloModels(),
|
||||||
@@ -107,13 +233,13 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
for _, models := range allModels {
|
for _, models := range allModels {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if m != nil && m.ID == modelID {
|
if m != nil && m.ID == modelID {
|
||||||
return m
|
return cloneModelInfo(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
// Check Antigravity static config
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil {
|
||||||
return &ModelInfo{
|
return &ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Thinking: cfg.Thinking,
|
Thinking: cfg.Thinking,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
198
internal/registry/model_updater.go
Normal file
198
internal/registry/model_updater.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
modelsFetchTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var modelsURLs = []string{
|
||||||
|
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
|
||||||
|
"https://models.router-for.me/models.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:embed models/models.json
|
||||||
|
var embeddedModelsJSON []byte
|
||||||
|
|
||||||
|
type modelStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
data *staticModelsJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsCatalogStore = &modelStore{}
|
||||||
|
|
||||||
|
var updaterOnce sync.Once
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Load embedded data as fallback on startup.
|
||||||
|
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||||
|
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartModelsUpdater runs a one-time models refresh on startup.
|
||||||
|
// It blocks until the startup fetch attempt finishes so service initialization
|
||||||
|
// can wait for the refreshed catalog before registering auth-backed models.
|
||||||
|
// Safe to call multiple times; only one refresh will run.
|
||||||
|
func StartModelsUpdater(ctx context.Context) {
|
||||||
|
updaterOnce.Do(func() {
|
||||||
|
runModelsUpdater(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runModelsUpdater(ctx context.Context) {
|
||||||
|
// Try network fetch once on startup, then stop.
|
||||||
|
// Periodic refresh is disabled - models are only refreshed at startup.
|
||||||
|
tryRefreshModels(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tryRefreshModels(ctx context.Context) {
|
||||||
|
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||||
|
for _, url := range modelsURLs {
|
||||||
|
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch request creation failed for %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.Body.Close()
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("models fetch read error from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := loadModelsFromBytes(data, url); err != nil {
|
||||||
|
log.Warnf("models parse failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("models updated from %s", url)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warn("models refresh failed from all URLs, using current data")
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModelsFromBytes(data []byte, source string) error {
|
||||||
|
var parsed staticModelsJSON
|
||||||
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||||
|
return fmt.Errorf("%s: decode models catalog: %w", source, err)
|
||||||
|
}
|
||||||
|
if err := validateModelsCatalog(&parsed); err != nil {
|
||||||
|
return fmt.Errorf("%s: validate models catalog: %w", source, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsCatalogStore.mu.Lock()
|
||||||
|
modelsCatalogStore.data = &parsed
|
||||||
|
modelsCatalogStore.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModels() *staticModelsJSON {
|
||||||
|
modelsCatalogStore.mu.RLock()
|
||||||
|
defer modelsCatalogStore.mu.RUnlock()
|
||||||
|
return modelsCatalogStore.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateModelsCatalog(data *staticModelsJSON) error {
|
||||||
|
if data == nil {
|
||||||
|
return fmt.Errorf("catalog is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredSections := []struct {
|
||||||
|
name string
|
||||||
|
models []*ModelInfo
|
||||||
|
}{
|
||||||
|
{name: "claude", models: data.Claude},
|
||||||
|
{name: "gemini", models: data.Gemini},
|
||||||
|
{name: "vertex", models: data.Vertex},
|
||||||
|
{name: "gemini-cli", models: data.GeminiCLI},
|
||||||
|
{name: "aistudio", models: data.AIStudio},
|
||||||
|
{name: "codex-free", models: data.CodexFree},
|
||||||
|
{name: "codex-team", models: data.CodexTeam},
|
||||||
|
{name: "codex-plus", models: data.CodexPlus},
|
||||||
|
{name: "codex-pro", models: data.CodexPro},
|
||||||
|
{name: "qwen", models: data.Qwen},
|
||||||
|
{name: "iflow", models: data.IFlow},
|
||||||
|
{name: "kimi", models: data.Kimi},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, section := range requiredSections {
|
||||||
|
if err := validateModelSection(section.name, section.models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := validateAntigravitySection(data.Antigravity); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateModelSection(section string, models []*ModelInfo) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return fmt.Errorf("%s section is empty", section)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, len(models))
|
||||||
|
for i, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
return fmt.Errorf("%s[%d] is null", section, i)
|
||||||
|
}
|
||||||
|
modelID := strings.TrimSpace(model.ID)
|
||||||
|
if modelID == "" {
|
||||||
|
return fmt.Errorf("%s[%d] has empty id", section, i)
|
||||||
|
}
|
||||||
|
if _, exists := seen[modelID]; exists {
|
||||||
|
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
|
||||||
|
}
|
||||||
|
seen[modelID] = struct{}{}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error {
|
||||||
|
if len(configs) == 0 {
|
||||||
|
return fmt.Errorf("antigravity section is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelID, cfg := range configs {
|
||||||
|
trimmedID := strings.TrimSpace(modelID)
|
||||||
|
if trimmedID == "" {
|
||||||
|
return fmt.Errorf("antigravity contains empty model id")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return fmt.Errorf("antigravity[%q] is null", trimmedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
2598
internal/registry/models/models.json
Normal file
2598
internal/registry/models/models.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
} else if system.Type == gjson.String && system.String() != "" {
|
||||||
|
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||||
|
partJSON, _ = sjson.Set(partJSON, "text", system.String())
|
||||||
|
result += "," + partJSON
|
||||||
}
|
}
|
||||||
result += "]"
|
result += "]"
|
||||||
|
|
||||||
|
|||||||
@@ -842,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
|
|||||||
executor := NewClaudeExecutor(&config.Config{})
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
"api_key": "key-123",
|
"api_key": "key-123",
|
||||||
"base_url": server.URL,
|
"base_url": server.URL,
|
||||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||||
}}
|
}}
|
||||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
@@ -980,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
|||||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test case 1: String system prompt is preserved and converted to a content block
|
||||||
|
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
system := gjson.GetBytes(out, "system")
|
||||||
|
if !system.IsArray() {
|
||||||
|
t.Fatalf("system should be an array, got %s", system.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks := system.Array()
|
||||||
|
if len(blocks) != 3 {
|
||||||
|
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
|
||||||
|
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
|
||||||
|
}
|
||||||
|
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||||
|
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
|
||||||
|
}
|
||||||
|
if blocks[2].Get("text").String() != "You are a helpful assistant." {
|
||||||
|
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||||
|
}
|
||||||
|
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
|
||||||
|
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 2: Strict mode drops the string system prompt
|
||||||
|
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, true)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 2 {
|
||||||
|
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 3: Empty string system prompt does not produce a spurious block
|
||||||
|
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 2 {
|
||||||
|
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 4: Array system prompt is unaffected by the string handling
|
||||||
|
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 3 {
|
||||||
|
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
if blocks[2].Get("text").String() != "Be concise." {
|
||||||
|
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 5: Special characters in string system prompt survive conversion
|
||||||
|
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 3 {
|
||||||
|
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
|
||||||
|
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -522,9 +522,9 @@ func detectLastConversationRole(body []byte) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch item.Get("type").String() {
|
switch item.Get("type").String() {
|
||||||
case "function_call", "function_call_arguments":
|
case "function_call", "function_call_arguments", "computer_call":
|
||||||
return "assistant"
|
return "assistant"
|
||||||
case "function_call_output", "function_call_response", "tool_result":
|
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||||
return "tool"
|
return "tool"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -832,6 +832,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
|||||||
if tools.IsArray() {
|
if tools.IsArray() {
|
||||||
for _, tool := range tools.Array() {
|
for _, tool := range tools.Array() {
|
||||||
toolType := tool.Get("type").String()
|
toolType := tool.Get("type").String()
|
||||||
|
if isGitHubCopilotResponsesBuiltinTool(toolType) {
|
||||||
|
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Accept OpenAI format (type="function") and Claude format
|
// Accept OpenAI format (type="function") and Claude format
|
||||||
// (no type field, but has top-level name + input_schema).
|
// (no type field, but has top-level name + input_schema).
|
||||||
if toolType != "" && toolType != "function" {
|
if toolType != "" && toolType != "function" {
|
||||||
@@ -879,6 +883,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
if toolChoice.Type == gjson.JSON {
|
if toolChoice.Type == gjson.JSON {
|
||||||
choiceType := toolChoice.Get("type").String()
|
choiceType := toolChoice.Get("type").String()
|
||||||
|
if isGitHubCopilotResponsesBuiltinTool(choiceType) {
|
||||||
|
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(toolChoice.Raw))
|
||||||
|
return body
|
||||||
|
}
|
||||||
if choiceType == "function" {
|
if choiceType == "function" {
|
||||||
name := toolChoice.Get("name").String()
|
name := toolChoice.Get("name").String()
|
||||||
if name == "" {
|
if name == "" {
|
||||||
@@ -896,6 +904,15 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isGitHubCopilotResponsesBuiltinTool reports whether toolType, after trimming
// surrounding whitespace, is one of the built-in tool types "computer" or
// "computer_use_preview". Matching is case-sensitive.
func isGitHubCopilotResponsesBuiltinTool(toolType string) bool {
	normalized := strings.TrimSpace(toolType)
	return normalized == "computer" || normalized == "computer_use_preview"
}
|
||||||
|
|
||||||
func collectTextFromNode(node gjson.Result) string {
|
func collectTextFromNode(node gjson.Result) string {
|
||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
content := m.Get("content")
|
content := m.Get("content")
|
||||||
|
|
||||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||||
// system -> system_instruction as a user message style
|
// system -> systemInstruction as a user message style
|
||||||
if content.Type == gjson.String {
|
if content.Type == gjson.String {
|
||||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
|
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
|
||||||
systemPartIndex++
|
systemPartIndex++
|
||||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||||
systemPartIndex++
|
systemPartIndex++
|
||||||
} else if content.IsArray() {
|
} else if content.IsArray() {
|
||||||
contents := content.Array()
|
contents := content.Array()
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||||
for j := 0; j < len(contents); j++ {
|
for j := 0; j < len(contents); j++ {
|
||||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||||
systemPartIndex++
|
systemPartIndex++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
if instructions := root.Get("instructions"); instructions.Exists() {
|
if instructions := root.Get("instructions"); instructions.Exists() {
|
||||||
systemInstr := `{"parts":[{"text":""}]}`
|
systemInstr := `{"parts":[{"text":""}]}`
|
||||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert input messages to Gemini contents format
|
// Convert input messages to Gemini contents format
|
||||||
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
if strings.EqualFold(itemRole, "system") {
|
if strings.EqualFold(itemRole, "system") {
|
||||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||||
systemInstr := ""
|
systemInstr := ""
|
||||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
|
||||||
systemInstr = systemInstructionResult.Raw
|
systemInstr = systemInstructionResult.Raw
|
||||||
} else {
|
} else {
|
||||||
systemInstr = `{"parts":[]}`
|
systemInstr = `{"parts":[]}`
|
||||||
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
if systemInstr != `{"parts":[]}` {
|
if systemInstr != `{"parts":[]}` {
|
||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
)
|
)
|
||||||
@@ -149,6 +150,16 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||||
|
// For codex auth files, extract plan_type from the JWT id_token.
|
||||||
|
if provider == "codex" {
|
||||||
|
if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
|
||||||
|
if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
|
||||||
|
if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
|
||||||
|
a.Attributes["plan_type"] = pt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if provider == "gemini-cli" {
|
if provider == "gemini-cli" {
|
||||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||||
for _, v := range virtuals {
|
for _, v := range virtuals {
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ const (
|
|||||||
wsTurnStateHeader = "x-codex-turn-state"
|
wsTurnStateHeader = "x-codex-turn-state"
|
||||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||||
wsPayloadLogMaxSize = 2048
|
wsPayloadLogMaxSize = 2048
|
||||||
|
wsBodyLogMaxSize = 64 * 1024
|
||||||
|
wsBodyLogTruncated = "\n[websocket log truncated]\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||||
@@ -825,18 +827,71 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
|
|||||||
if builder == nil {
|
if builder == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if builder.Len() >= wsBodyLogMaxSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
trimmedPayload := bytes.TrimSpace(payload)
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
if len(trimmedPayload) == 0 {
|
if len(trimmedPayload) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if builder.Len() > 0 {
|
if builder.Len() > 0 {
|
||||||
builder.WriteString("\n")
|
if !appendWebsocketLogString(builder, "\n") {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
builder.WriteString("websocket.")
|
if !appendWebsocketLogString(builder, "websocket.") {
|
||||||
builder.WriteString(eventType)
|
return
|
||||||
builder.WriteString("\n")
|
}
|
||||||
builder.Write(trimmedPayload)
|
if !appendWebsocketLogString(builder, eventType) {
|
||||||
builder.WriteString("\n")
|
return
|
||||||
|
}
|
||||||
|
if !appendWebsocketLogString(builder, "\n") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
|
||||||
|
appendWebsocketLogString(builder, wsBodyLogTruncated)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appendWebsocketLogString(builder, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
|
||||||
|
if builder == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remaining := wsBodyLogMaxSize - builder.Len()
|
||||||
|
if remaining <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(value) <= remaining {
|
||||||
|
builder.WriteString(value)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
builder.WriteString(value[:remaining])
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
|
||||||
|
if builder == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remaining := wsBodyLogMaxSize - builder.Len()
|
||||||
|
if remaining <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(value) <= remaining {
|
||||||
|
builder.Write(value)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
limit := remaining - reserveForSuffix
|
||||||
|
if limit < 0 {
|
||||||
|
limit = 0
|
||||||
|
}
|
||||||
|
if limit > len(value) {
|
||||||
|
limit = len(value)
|
||||||
|
}
|
||||||
|
builder.Write(value[:limit])
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func websocketPayloadEventType(payload []byte) string {
|
func websocketPayloadEventType(payload []byte) string {
|
||||||
|
|||||||
@@ -266,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
|
||||||
|
|
||||||
|
appendWebsocketEvent(&builder, "request", payload)
|
||||||
|
|
||||||
|
got := builder.String()
|
||||||
|
if len(got) > wsBodyLogMaxSize {
|
||||||
|
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, wsBodyLogTruncated) {
|
||||||
|
t.Fatalf("expected truncation marker in body log")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
|
||||||
|
initial := builder.String()
|
||||||
|
|
||||||
|
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
|
||||||
|
|
||||||
|
if builder.String() != initial {
|
||||||
|
t.Fatalf("builder grew after reaching limit")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl
|
|||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: metadata,
|
Metadata: metadata,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"plan_type": planType,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -213,6 +213,26 @@ func (m *Manager) syncScheduler() {
|
|||||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||||
|
// supportedModelSet is rebuilt from the current global model registry state.
|
||||||
|
// This must be called after models have been registered for a newly added auth,
|
||||||
|
// because the initial scheduler.upsertAuth during Register/Update runs before
|
||||||
|
// registerModelsForAuth and therefore snapshots an empty model set.
|
||||||
|
func (m *Manager) RefreshSchedulerEntry(authID string) {
|
||||||
|
if m == nil || m.scheduler == nil || authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.RLock()
|
||||||
|
auth, ok := m.auths[authID]
|
||||||
|
if !ok || auth == nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
snapshot := auth.Clone()
|
||||||
|
m.mu.RUnlock()
|
||||||
|
m.scheduler.upsertAuth(snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) SetSelector(selector Selector) {
|
func (m *Manager) SetSelector(selector Selector) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return
|
return
|
||||||
@@ -2038,6 +2058,10 @@ func shouldRetrySchedulerPick(err error) bool {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
var cooldownErr *modelCooldownError
|
||||||
|
if errors.As(err, &cooldownErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
var authErr *Error
|
var authErr *Error
|
||||||
if !errors.As(err, &authErr) || authErr == nil {
|
if !errors.As(err, &authErr) || authErr == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
163
sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
Normal file
163
sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schedulerProviderTestExecutor struct {
|
||||||
|
provider string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
prime func(*Manager, *Auth) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "register",
|
||||||
|
prime: func(manager *Manager, auth *Auth) error {
|
||||||
|
_, errRegister := manager.Register(ctx, auth)
|
||||||
|
return errRegister
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update",
|
||||||
|
prime: func(manager *Manager, auth *Auth) error {
|
||||||
|
_, errRegister := manager.Register(ctx, auth)
|
||||||
|
if errRegister != nil {
|
||||||
|
return errRegister
|
||||||
|
}
|
||||||
|
updated := auth.Clone()
|
||||||
|
updated.Metadata = map[string]any{"updated": true}
|
||||||
|
_, errUpdate := manager.Update(ctx, updated)
|
||||||
|
return errUpdate
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "refresh-entry-" + testCase.name,
|
||||||
|
Provider: "gemini",
|
||||||
|
}
|
||||||
|
if errPrime := testCase.prime(manager, auth); errPrime != nil {
|
||||||
|
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
|
||||||
|
}
|
||||||
|
|
||||||
|
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
|
||||||
|
|
||||||
|
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
var authErr *Error
|
||||||
|
if !errors.As(errPick, &authErr) || authErr == nil {
|
||||||
|
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
|
||||||
|
}
|
||||||
|
if authErr.Code != "auth_not_found" {
|
||||||
|
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.RefreshSchedulerEntry(auth.ID)
|
||||||
|
|
||||||
|
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() after refresh error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != auth.ID {
|
||||||
|
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
|
||||||
|
|
||||||
|
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
|
||||||
|
|
||||||
|
oldAuth := &Auth{
|
||||||
|
ID: "cooldown-stale-old",
|
||||||
|
Provider: "gemini",
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register old auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.MarkResult(ctx, Result{
|
||||||
|
AuthID: oldAuth.ID,
|
||||||
|
Provider: "gemini",
|
||||||
|
Model: "scheduler-cooldown-rebuild-model",
|
||||||
|
Success: false,
|
||||||
|
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
|
||||||
|
})
|
||||||
|
|
||||||
|
newAuth := &Auth{
|
||||||
|
ID: "cooldown-stale-new",
|
||||||
|
Provider: "gemini",
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register new auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(newAuth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
var cooldownErr *modelCooldownError
|
||||||
|
if !errors.As(errPick, &cooldownErr) {
|
||||||
|
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickNext() error = %v", errPick)
|
||||||
|
}
|
||||||
|
if executor == nil {
|
||||||
|
t.Fatal("pickNext() executor = nil")
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != newAuth.ID {
|
||||||
|
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -250,17 +250,41 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
|||||||
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
|
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
predicate := triedPredicate(tried)
|
||||||
|
candidateShards := make([]*modelScheduler, len(normalized))
|
||||||
|
bestPriority := 0
|
||||||
|
hasCandidate := false
|
||||||
|
now := time.Now()
|
||||||
|
for providerIndex, providerKey := range normalized {
|
||||||
|
providerState := s.providers[providerKey]
|
||||||
|
if providerState == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shard := providerState.ensureModelLocked(modelKey, now)
|
||||||
|
candidateShards[providerIndex] = shard
|
||||||
|
if shard == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
|
||||||
|
if !okPriority {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !hasCandidate || priorityReady > bestPriority {
|
||||||
|
bestPriority = priorityReady
|
||||||
|
hasCandidate = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasCandidate {
|
||||||
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
|
}
|
||||||
|
|
||||||
if s.strategy == schedulerStrategyFillFirst {
|
if s.strategy == schedulerStrategyFillFirst {
|
||||||
for _, providerKey := range normalized {
|
for providerIndex, providerKey := range normalized {
|
||||||
providerState := s.providers[providerKey]
|
shard := candidateShards[providerIndex]
|
||||||
if providerState == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
|
||||||
if shard == nil {
|
if shard == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
picked := shard.pickReadyLocked(false, s.strategy, triedPredicate(tried))
|
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
|
||||||
if picked != nil {
|
if picked != nil {
|
||||||
return picked, providerKey, nil
|
return picked, providerKey, nil
|
||||||
}
|
}
|
||||||
@@ -276,15 +300,11 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
|||||||
for offset := 0; offset < len(normalized); offset++ {
|
for offset := 0; offset < len(normalized); offset++ {
|
||||||
providerIndex := (start + offset) % len(normalized)
|
providerIndex := (start + offset) % len(normalized)
|
||||||
providerKey := normalized[providerIndex]
|
providerKey := normalized[providerIndex]
|
||||||
providerState := s.providers[providerKey]
|
shard := candidateShards[providerIndex]
|
||||||
if providerState == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
|
||||||
if shard == nil {
|
if shard == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
picked := shard.pickReadyLocked(false, schedulerStrategyRoundRobin, triedPredicate(tried))
|
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
|
||||||
if picked == nil {
|
if picked == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -629,6 +649,19 @@ func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedule
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m.promoteExpiredLocked(time.Now())
|
m.promoteExpiredLocked(time.Now())
|
||||||
|
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
|
||||||
|
if !okPriority {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
|
||||||
|
// The caller must ensure expired entries are already promoted when needed.
|
||||||
|
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
|
||||||
|
if m == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
for _, priority := range m.priorityOrder {
|
for _, priority := range m.priorityOrder {
|
||||||
bucket := m.readyByPriority[priority]
|
bucket := m.readyByPriority[priority]
|
||||||
if bucket == nil {
|
if bucket == nil {
|
||||||
@@ -638,17 +671,37 @@ func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedule
|
|||||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||||
view = &bucket.ws
|
view = &bucket.ws
|
||||||
}
|
}
|
||||||
var picked *scheduledAuth
|
if view.pickFirst(predicate) != nil {
|
||||||
if strategy == schedulerStrategyFillFirst {
|
return priority, true
|
||||||
picked = view.pickFirst(predicate)
|
|
||||||
} else {
|
|
||||||
picked = view.pickRoundRobin(predicate)
|
|
||||||
}
|
|
||||||
if picked != nil && picked.auth != nil {
|
|
||||||
return picked.auth
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
|
||||||
|
// The caller must ensure expired entries are already promoted when needed.
|
||||||
|
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bucket := m.readyByPriority[priority]
|
||||||
|
if bucket == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
view := &bucket.all
|
||||||
|
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||||
|
view = &bucket.ws
|
||||||
|
}
|
||||||
|
var picked *scheduledAuth
|
||||||
|
if strategy == schedulerStrategyFillFirst {
|
||||||
|
picked = view.pickFirst(predicate)
|
||||||
|
} else {
|
||||||
|
picked = view.pickRoundRobin(predicate)
|
||||||
|
}
|
||||||
|
if picked == nil || picked.auth == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return picked.auth
|
||||||
}
|
}
|
||||||
|
|
||||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||||
|
|||||||
@@ -176,6 +176,25 @@ func BenchmarkManagerPickNextMixed500(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
|
||||||
|
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil || provider == "" {
|
||||||
|
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
|
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
|
||||||
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@@ -237,6 +237,41 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
model := "gpt-default"
|
||||||
|
registerSchedulerModels(t, "provider-low", model, "low")
|
||||||
|
registerSchedulerModels(t, "provider-high-a", model, "high-a")
|
||||||
|
registerSchedulerModels(t, "provider-high-b", model, "high-b")
|
||||||
|
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
|
||||||
|
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
|
||||||
|
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
|
||||||
|
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
|
||||||
|
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
|
||||||
|
for index := range wantProviders {
|
||||||
|
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickMixed() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if provider != wantProviders[index] {
|
||||||
|
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||||
|
}
|
||||||
|
if got.ID != wantIDs[index] {
|
||||||
|
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -323,6 +323,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
|||||||
// This operation may block on network calls, but the auth configuration
|
// This operation may block on network calls, but the auth configuration
|
||||||
// is already effective at this point.
|
// is already effective at this point.
|
||||||
s.registerModelsForAuth(auth)
|
s.registerModelsForAuth(auth)
|
||||||
|
|
||||||
|
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
|
||||||
|
// from the now-populated global model registry. Without this, newly added auths
|
||||||
|
// have an empty supportedModelSet (because Register/Update upserts into the
|
||||||
|
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
|
||||||
|
s.coreManager.RefreshSchedulerEntry(auth.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||||
@@ -852,7 +858,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
case "codex":
|
case "codex":
|
||||||
models = registry.GetOpenAIModels()
|
codexPlanType := ""
|
||||||
|
if a.Attributes != nil {
|
||||||
|
codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
|
||||||
|
}
|
||||||
|
switch strings.ToLower(codexPlanType) {
|
||||||
|
case "pro":
|
||||||
|
models = registry.GetCodexProModels()
|
||||||
|
case "plus":
|
||||||
|
models = registry.GetCodexPlusModels()
|
||||||
|
case "team":
|
||||||
|
models = registry.GetCodexTeamModels()
|
||||||
|
case "free":
|
||||||
|
models = registry.GetCodexFreeModels()
|
||||||
|
default:
|
||||||
|
models = registry.GetCodexProModels()
|
||||||
|
}
|
||||||
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
||||||
if len(entry.Models) > 0 {
|
if len(entry.Models) > 0 {
|
||||||
models = buildCodexConfigModels(entry)
|
models = buildCodexConfigModels(entry)
|
||||||
|
|||||||
Reference in New Issue
Block a user