mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-13 09:13:17 +00:00
Compare commits
104 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34c8ccb961 | ||
|
|
d08e164af3 | ||
|
|
8178efaeda | ||
|
|
86d5db472a | ||
|
|
020d36f6e8 | ||
|
|
1db23979e8 | ||
|
|
c3d5dbe96f | ||
|
|
5484489406 | ||
|
|
0ac52da460 | ||
|
|
817cebb321 | ||
|
|
683f3709d6 | ||
|
|
dbd42a42b2 | ||
|
|
ec24baf757 | ||
|
|
dea3e74d35 | ||
|
|
a6c3042e34 | ||
|
|
861537c9bd | ||
|
|
8c92cb0883 | ||
|
|
89d7be9525 | ||
|
|
2b79d7f22f | ||
|
|
2bb686f594 | ||
|
|
163fe287ce | ||
|
|
70988d387b | ||
|
|
52058a1659 | ||
|
|
df5595a0c9 | ||
|
|
ddaa9d2436 | ||
|
|
7b7b258c38 | ||
|
|
a00f774f5a | ||
|
|
9daf1ba8b5 | ||
|
|
76f2359637 | ||
|
|
dcb1c9be8a | ||
|
|
a24f4ace78 | ||
|
|
c631df8c3b | ||
|
|
54c3eb1b1e | ||
|
|
bb28cd26ad | ||
|
|
046865461e | ||
|
|
cf74ed2f0c | ||
|
|
e333fbea3d | ||
|
|
efbe36d1d4 | ||
|
|
8553cfa40e | ||
|
|
30d5c95b26 | ||
|
|
d1e3195e6f | ||
|
|
05a35662ae | ||
|
|
ce53d3a287 | ||
|
|
4cc99e7449 | ||
|
|
71773fe032 | ||
|
|
a1e0fa0f39 | ||
|
|
fc2f0b6983 | ||
|
|
5c9997cdac | ||
|
|
6f81046730 | ||
|
|
0687472d01 | ||
|
|
7739738fb3 | ||
|
|
99d1ce247b | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 | ||
|
|
182b31963a | ||
|
|
4f48e5254a | ||
|
|
15dd5db1d7 | ||
|
|
424711b718 | ||
|
|
91a2b1f0b4 | ||
|
|
2b134fc378 | ||
|
|
b9153719b0 | ||
|
|
631e5c8331 | ||
|
|
e9c60a0a67 | ||
|
|
98a1bb5a7f | ||
|
|
ca90487a8c | ||
|
|
1042489f85 | ||
|
|
38277c1ea6 | ||
|
|
ee0c24628f | ||
|
|
3a18f6fcca | ||
|
|
099e734a02 | ||
|
|
a52da26b5d | ||
|
|
522a68a4ea | ||
|
|
a02eda54d0 | ||
|
|
97ef633c57 | ||
|
|
dae8463ba1 | ||
|
|
7c1299922e | ||
|
|
ddcf1f279d | ||
|
|
7e6bb8fdc5 | ||
|
|
9cee8ef87b | ||
|
|
93fb841bcb | ||
|
|
0c05131aeb | ||
|
|
5ebc58fab4 | ||
|
|
2b609dd891 | ||
|
|
a8cbc68c3e | ||
|
|
11a795a01c | ||
|
|
89c428216e | ||
|
|
2695a99623 | ||
|
|
242aecd924 | ||
|
|
ce8cc1ba33 | ||
|
|
ad5253bd2b | ||
|
|
97fdd2e088 | ||
|
|
9397f7049f | ||
|
|
553d6f50ea | ||
|
|
dd44413ba5 | ||
|
|
10fa0f2062 | ||
|
|
30338ecec4 | ||
|
|
9a37defed3 | ||
|
|
c83a057996 | ||
|
|
b7588428c5 |
8
.github/workflows/docker-image.yml
vendored
8
.github/workflows/docker-image.yml
vendored
@@ -16,6 +16,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
@@ -47,6 +51,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
|
||||
4
.github/workflows/pr-test-build.yml
vendored
4
.github/workflows/pr-test-build.yml
vendored
@@ -12,6 +12,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
|
||||
4
.github/workflows/release.yaml
vendored
4
.github/workflows/release.yaml
vendored
@@ -16,6 +16,10 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
|
||||
117
README.md
117
README.md
@@ -8,123 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
|
||||
|
||||
The Plus release stays in lockstep with the mainline features.
|
||||
|
||||
## Differences from the Mainline
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
## New Features (Plus Enhanced)
|
||||
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
## Kiro Authentication
|
||||
|
||||
### CLI Login
|
||||
|
||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
|
||||
**AWS Builder ID** (recommended):
|
||||
|
||||
```bash
|
||||
# Device code flow
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# Authorization code flow
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**Import token from Kiro IDE:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
To get a token from Kiro IDE:
|
||||
|
||||
1. Open Kiro IDE and login with Google (or GitHub)
|
||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
||||
3. Run: `./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# Specify region
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**Additional flags:**
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
||||
|
||||
### Web-based OAuth Login
|
||||
|
||||
Access the Kiro OAuth web interface at:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
||||
- AWS Builder ID login
|
||||
- AWS Identity Center (IDC) login
|
||||
- Token import from Kiro IDE
|
||||
|
||||
## Quick Deployment with Docker
|
||||
|
||||
### One-Command Deployment
|
||||
|
||||
```bash
|
||||
# Create deployment directory
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# Create docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# Download example config
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# Pull and start
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Edit `config.yaml` before starting:
|
||||
|
||||
```yaml
|
||||
# Basic configuration example
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# Add your provider configurations here
|
||||
```
|
||||
|
||||
### Update to Latest Version
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||
|
||||
121
README_CN.md
121
README_CN.md
@@ -6,125 +6,6 @@
|
||||
|
||||
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
||||
|
||||
该 Plus 版本的主线功能与主线功能强制同步。
|
||||
|
||||
## 与主线版本版本差异
|
||||
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
## 新增功能 (Plus 增强版)
|
||||
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
||||
|
||||
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
### 命令行登录
|
||||
|
||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
||||
|
||||
**AWS Builder ID**(推荐):
|
||||
|
||||
```bash
|
||||
# 设备码流程
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# 授权码流程
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**从 Kiro IDE 导入令牌:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
获取令牌步骤:
|
||||
|
||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# 指定区域
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**附加参数:**
|
||||
|
||||
| 参数 | 说明 |
|
||||
|------|------|
|
||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
||||
|
||||
### 网页端 OAuth 登录
|
||||
|
||||
访问 Kiro OAuth 网页认证界面:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
||||
- AWS Builder ID 登录
|
||||
- AWS Identity Center (IDC) 登录
|
||||
- 从 Kiro IDE 导入令牌
|
||||
|
||||
## Docker 快速部署
|
||||
|
||||
### 一键部署
|
||||
|
||||
```bash
|
||||
# 创建部署目录
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# 创建 docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# 下载示例配置
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# 拉取并启动
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
启动前请编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# 基本配置示例
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# 在此添加你的供应商配置
|
||||
```
|
||||
|
||||
### 更新到最新版本
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||
@@ -133,4 +14,4 @@ docker compose pull && docker compose up -d
|
||||
|
||||
## 许可证
|
||||
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
|
||||
275
cmd/fetch_antigravity_models/main.go
Normal file
275
cmd/fetch_antigravity_models/main.go
Normal file
@@ -0,0 +1,275 @@
|
||||
// Command fetch_antigravity_models connects to the Antigravity API using the
|
||||
// stored auth credentials and saves the dynamically fetched model list to a
|
||||
// JSON file for inspection or offline use.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// go run ./cmd/fetch_antigravity_models [flags]
|
||||
//
|
||||
// Flags:
|
||||
//
|
||||
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
|
||||
// --output <path> Output JSON file path (default: "antigravity_models.json")
|
||||
// --pretty Pretty-print the output JSON (default: true)
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
|
||||
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logging.SetupBaseLogger()
|
||||
log.SetLevel(log.InfoLevel)
|
||||
}
|
||||
|
||||
// modelOutput wraps the fetched model list with fetch metadata.
|
||||
type modelOutput struct {
|
||||
Models []modelEntry `json:"models"`
|
||||
}
|
||||
|
||||
// modelEntry contains only the fields we want to keep for static model definitions.
|
||||
type modelEntry struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ContextLength int `json:"context_length,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var authsDir string
|
||||
var outputPath string
|
||||
var pretty bool
|
||||
|
||||
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
|
||||
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
|
||||
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
|
||||
flag.Parse()
|
||||
|
||||
// Resolve relative paths against the working directory.
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if !filepath.IsAbs(authsDir) {
|
||||
authsDir = filepath.Join(wd, authsDir)
|
||||
}
|
||||
if !filepath.IsAbs(outputPath) {
|
||||
outputPath = filepath.Join(wd, outputPath)
|
||||
}
|
||||
|
||||
fmt.Printf("Scanning auth files in: %s\n", authsDir)
|
||||
|
||||
// Load all auth records from the directory.
|
||||
fileStore := sdkauth.NewFileTokenStore()
|
||||
fileStore.SetBaseDir(authsDir)
|
||||
|
||||
ctx := context.Background()
|
||||
auths, err := fileStore.List(ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Find the first enabled antigravity auth.
|
||||
var chosen *coreauth.Auth
|
||||
for _, a := range auths {
|
||||
if a == nil || a.Disabled {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
|
||||
chosen = a
|
||||
break
|
||||
}
|
||||
}
|
||||
if chosen == nil {
|
||||
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
|
||||
|
||||
// Fetch models from the upstream Antigravity API.
|
||||
fmt.Println("Fetching Antigravity model list from upstream...")
|
||||
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
models := fetchModels(fetchCtx, chosen)
|
||||
if len(models) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
|
||||
} else {
|
||||
fmt.Printf("Fetched %d models.\n", len(models))
|
||||
}
|
||||
|
||||
// Build the output payload.
|
||||
out := modelOutput{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
// Marshal to JSON.
|
||||
var raw []byte
|
||||
if pretty {
|
||||
raw, err = json.MarshalIndent(out, "", " ")
|
||||
} else {
|
||||
raw, err = json.Marshal(out)
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Model list saved to: %s\n", outputPath)
|
||||
}
|
||||
|
||||
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||
if accessToken == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
|
||||
return nil
|
||||
}
|
||||
|
||||
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
|
||||
|
||||
for _, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
|
||||
var payload []byte
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||
}
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
payload = []byte(`{}`)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
|
||||
if errReq != nil {
|
||||
continue
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
|
||||
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||
httpClient.Transport = transport
|
||||
}
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
httpResp.Body.Close()
|
||||
if errRead != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
continue
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
var models []modelEntry
|
||||
|
||||
for originalName, modelData := range result.Map() {
|
||||
modelID := strings.TrimSpace(originalName)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
// Skip internal/experimental models
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := modelData.Get("displayName").String()
|
||||
if displayName == "" {
|
||||
displayName = modelID
|
||||
}
|
||||
|
||||
entry := modelEntry{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
OwnedBy: "antigravity",
|
||||
Type: "antigravity",
|
||||
DisplayName: displayName,
|
||||
Name: modelID,
|
||||
Description: displayName,
|
||||
}
|
||||
|
||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||
entry.ContextLength = int(maxTok)
|
||||
}
|
||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||
entry.MaxCompletionTokens = int(maxOut)
|
||||
}
|
||||
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func metaStringValue(m map[string]interface{}, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
@@ -78,6 +79,8 @@ func main() {
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var gitlabLogin bool
|
||||
var gitlabTokenLogin bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
@@ -110,6 +113,8 @@ func main() {
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
|
||||
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
@@ -526,6 +531,10 @@ func main() {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if gitlabLogin {
|
||||
cmd.DoGitLabLogin(cfg, options)
|
||||
} else if gitlabTokenLogin {
|
||||
cmd.DoGitLabTokenLogin(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
@@ -573,6 +582,7 @@ func main() {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
@@ -643,15 +653,16 @@ func main() {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,8 @@ error-logs-max-files: 10
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ''
|
||||
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
@@ -115,6 +116,7 @@ nonstream-keepalive-interval: 0
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
@@ -133,6 +135,7 @@ nonstream-keepalive-interval: 0
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
@@ -151,6 +154,7 @@ nonstream-keepalive-interval: 0
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
@@ -178,6 +182,14 @@ nonstream-keepalive-interval: 0
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# Default headers for Codex OAuth model requests.
|
||||
# These are used only for file-backed/OAuth Codex requests when the client
|
||||
# does not send the header. `user-agent` applies to HTTP and websocket requests;
|
||||
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
|
||||
# codex-header-defaults:
|
||||
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
|
||||
# beta-features: "multi_agent"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
#kiro:
|
||||
@@ -215,10 +227,22 @@ nonstream-keepalive-interval: 0
|
||||
# api-key-entries:
|
||||
# - api-key: "sk-or-v1-...b780"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# # You may repeat the same alias to build an internal model pool.
|
||||
# # The client still sees only one alias in the model list.
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
# # and if the chosen upstream fails before producing output, the request will
|
||||
# # continue with the next upstream model in the same alias pool.
|
||||
# - name: "qwen3.5-plus"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "glm-5"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "claude-opus-4.66"
|
||||
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
@@ -226,6 +250,7 @@ nonstream-keepalive-interval: 0
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
|
||||
115
docs/gitlab-duo.md
Normal file
115
docs/gitlab-duo.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# GitLab Duo guide
|
||||
|
||||
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
|
||||
|
||||
It supports:
|
||||
|
||||
- OAuth login
|
||||
- personal access token login
|
||||
- automatic refresh of GitLab `direct_access` metadata
|
||||
- dynamic model discovery from GitLab metadata
|
||||
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
|
||||
- Claude-compatible and OpenAI-compatible downstream APIs
|
||||
|
||||
## What this means
|
||||
|
||||
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
|
||||
|
||||
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
|
||||
|
||||
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
|
||||
|
||||
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
|
||||
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
|
||||
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
|
||||
|
||||
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
|
||||
|
||||
- a stable alias: `gitlab-duo`
|
||||
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
|
||||
|
||||
## Login
|
||||
|
||||
OAuth login:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-login
|
||||
```
|
||||
|
||||
PAT login:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-token-login
|
||||
```
|
||||
|
||||
You can also provide inputs through environment variables:
|
||||
|
||||
```bash
|
||||
export GITLAB_BASE_URL=https://gitlab.com
|
||||
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- OAuth requires a GitLab OAuth application.
|
||||
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
|
||||
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
|
||||
|
||||
## Using the models
|
||||
|
||||
After login, start CLIProxyAPI normally and point your client at the local proxy.
|
||||
|
||||
You can select:
|
||||
|
||||
- `gitlab-duo` to use the current Duo-managed model for that account
|
||||
- the discovered provider model name if you want to pin it explicitly
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/models
|
||||
```
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gitlab-duo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
|
||||
|
||||
## How model freshness works
|
||||
|
||||
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
|
||||
|
||||
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
|
||||
|
||||
This matches GitLab's current public contract better than hardcoding model names.
|
||||
|
||||
## Current scope
|
||||
|
||||
The GitLab Duo provider now has:
|
||||
|
||||
- OAuth and PAT auth flows
|
||||
- runtime refresh of Duo gateway credentials
|
||||
- native Anthropic gateway routing
|
||||
- native OpenAI/Codex gateway routing
|
||||
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
|
||||
|
||||
Still out of scope today:
|
||||
|
||||
- websocket or session-specific parity beyond the current HTTP APIs
|
||||
- GitLab-specific IDE features that are not exposed through the public gateway contract
|
||||
|
||||
## References
|
||||
|
||||
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||
115
docs/gitlab-duo_CN.md
Normal file
115
docs/gitlab-duo_CN.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# GitLab Duo 使用说明
|
||||
|
||||
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
|
||||
|
||||
当前支持:
|
||||
|
||||
- OAuth 登录
|
||||
- personal access token 登录
|
||||
- 自动刷新 GitLab `direct_access` 元数据
|
||||
- 根据 GitLab 返回的元数据动态发现模型
|
||||
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
|
||||
- Claude 兼容与 OpenAI 兼容下游 API
|
||||
|
||||
## 这意味着什么
|
||||
|
||||
如果 GitLab Duo 返回的是 Anthropic 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
|
||||
|
||||
如果 GitLab Duo 返回的是 OpenAI 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
|
||||
|
||||
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider:
|
||||
|
||||
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
|
||||
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
|
||||
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
|
||||
|
||||
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
|
||||
|
||||
- 一个稳定别名:`gitlab-duo`
|
||||
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5` 或 `gpt-5-codex`
|
||||
|
||||
## 登录
|
||||
|
||||
OAuth 登录:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-login
|
||||
```
|
||||
|
||||
PAT 登录:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-token-login
|
||||
```
|
||||
|
||||
也可以通过环境变量提供输入:
|
||||
|
||||
```bash
|
||||
export GITLAB_BASE_URL=https://gitlab.com
|
||||
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- OAuth 方式需要一个 GitLab OAuth application。
|
||||
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上,`api` scope 是最稳妥的基线。
|
||||
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
|
||||
|
||||
## 如何使用模型
|
||||
|
||||
登录完成后,正常启动 CLIProxyAPI,并让客户端连接到本地代理。
|
||||
|
||||
你可以选择:
|
||||
|
||||
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
|
||||
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
|
||||
|
||||
示例:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/models
|
||||
```
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gitlab-duo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
如果该 GitLab 账号当前绑定的是 Anthropic 模型,Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型,OpenAI 兼容客户端可以通过 `/v1/chat/completions` 或 `/v1/responses` 使用它。
|
||||
|
||||
## 模型如何保持最新
|
||||
|
||||
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
|
||||
|
||||
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
|
||||
|
||||
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
|
||||
|
||||
## 当前覆盖范围
|
||||
|
||||
GitLab Duo Provider 目前已经具备:
|
||||
|
||||
- OAuth 和 PAT 登录流程
|
||||
- Duo gateway 凭据的运行时刷新
|
||||
- Anthropic gateway 原生路由
|
||||
- OpenAI/Codex gateway 原生路由
|
||||
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
|
||||
|
||||
当前仍未覆盖:
|
||||
|
||||
- websocket 或 session 级别的完全对齐
|
||||
- GitLab 公开 gateway 契约之外的 IDE 专有能力
|
||||
|
||||
## 参考资料
|
||||
|
||||
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||
278
gitlab-duo-codex-parity-plan.md
Normal file
278
gitlab-duo-codex-parity-plan.md
Normal file
@@ -0,0 +1,278 @@
|
||||
# Plan: GitLab Duo Codex Parity
|
||||
|
||||
**Generated**: 2026-03-10
|
||||
**Estimated Complexity**: High
|
||||
|
||||
## Overview
|
||||
Bring GitLab Duo support from the current "auth + basic executor" stage to the same practical level as `codex` inside `CLIProxyAPI`: a user logs in once, points external clients such as Claude Code at `CLIProxyAPI`, selects GitLab Duo-backed models, and gets stable streaming, multi-turn behavior, tool calling compatibility, and predictable model routing without manual provider-specific workarounds.
|
||||
|
||||
The core architectural shift is to stop treating GitLab Duo as only two REST wrappers (`/api/v4/chat/completions` and `/api/v4/code_suggestions/completions`) and instead use GitLab's `direct_access` contract as the primary runtime entrypoint wherever possible. Official GitLab docs confirm that `direct_access` returns AI gateway connection details, headers, token, and expiry; that contract is the closest path to codex-like provider behavior.
|
||||
|
||||
## Prerequisites
|
||||
- Official GitLab Duo API references confirmed during implementation:
|
||||
- `POST /api/v4/code_suggestions/direct_access`
|
||||
- `POST /api/v4/code_suggestions/completions`
|
||||
- `POST /api/v4/chat/completions`
|
||||
- Access to at least one real GitLab Duo account for manual verification.
|
||||
- One downstream client target for acceptance testing:
|
||||
- Claude Code against Claude-compatible endpoint
|
||||
- OpenAI-compatible client against `/v1/chat/completions` and `/v1/responses`
|
||||
- Existing PR branch as starting point:
|
||||
- `feat/gitlab-duo-auth`
|
||||
- PR [#2028](https://github.com/router-for-me/CLIProxyAPI/pull/2028)
|
||||
|
||||
## Definition Of Done
|
||||
- GitLab Duo models can be used via `CLIProxyAPI` from the same client surfaces that already work for `codex`.
|
||||
- Upstream streaming is real passthrough or faithful chunked forwarding, not synthetic whole-response replay.
|
||||
- Tool/function calling survives translation layers without dropping fields or corrupting names.
|
||||
- Multi-turn and session semantics are stable across `chat/completions`, `responses`, and Claude-compatible routes.
|
||||
- Model exposure stays current from GitLab metadata or gateway discovery without hardcoded stale model tables.
|
||||
- `go test ./...` stays green and at least one real manual end-to-end client flow is documented.
|
||||
|
||||
## Sprint 1: Contract And Gap Closure
|
||||
**Goal**: Replace assumptions with a hard compatibility contract between current `codex` behavior and what GitLab Duo can actually support.
|
||||
|
||||
**Demo/Validation**:
|
||||
- Written matrix showing `codex` features vs current GitLab Duo behavior.
|
||||
- One checked-in developer note or test fixture for real GitLab Duo payload examples.
|
||||
|
||||
### Task 1.1: Freeze Codex Parity Checklist
|
||||
- **Location**: [internal/runtime/executor/codex_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_executor.go), [internal/runtime/executor/codex_websockets_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_websockets_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go)
|
||||
- **Description**: Produce a concrete feature matrix for `codex`: HTTP execute, SSE execute, `/v1/responses`, websocket downstream path, tool calling, request IDs, session close semantics, and model registration behavior.
|
||||
- **Dependencies**: None
|
||||
- **Acceptance Criteria**:
|
||||
- A checklist exists in repo docs or issue notes.
|
||||
- Each capability is marked `required`, `optional`, or `not possible` for GitLab Duo.
|
||||
- **Validation**:
|
||||
- Review against current `codex` code paths.
|
||||
|
||||
### Task 1.2: Lock GitLab Duo Runtime Contract
|
||||
- **Location**: [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Validate the exact upstream contract we can rely on:
|
||||
- `direct_access` fields and refresh cadence
|
||||
- whether AI gateway path is usable directly
|
||||
- when `chat/completions` is available vs when fallback is required
|
||||
- what streaming shape is returned by `code_suggestions/completions?stream=true`
|
||||
- **Dependencies**: Task 1.1
|
||||
- **Acceptance Criteria**:
|
||||
- GitLab transport decision is explicit: `gateway-first`, `REST-first`, or `hybrid`.
|
||||
- Unknown areas are isolated behind feature flags, not spread across executor logic.
|
||||
- **Validation**:
|
||||
- Official docs + captured real responses from a Duo account.
|
||||
|
||||
### Task 1.3: Define Client-Facing Compatibility Targets
|
||||
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md), [gitlab-duo-codex-parity-plan.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/gitlab-duo-codex-parity-plan.md)
|
||||
- **Description**: Define exactly which external flows must work to call GitLab Duo support "like codex".
|
||||
- **Dependencies**: Task 1.2
|
||||
- **Acceptance Criteria**:
|
||||
- Required surfaces are listed:
|
||||
- Claude-compatible route
|
||||
- OpenAI `chat/completions`
|
||||
- OpenAI `responses`
|
||||
- optional downstream websocket path
|
||||
- Non-goals are explicit if GitLab upstream cannot support them.
|
||||
- **Validation**:
|
||||
- Maintainer review of stated scope.
|
||||
|
||||
## Sprint 2: Primary Transport Parity
|
||||
**Goal**: Move GitLab Duo execution onto a transport that supports codex-like runtime behavior.
|
||||
|
||||
**Demo/Validation**:
|
||||
- A GitLab Duo model works over real streaming through `/v1/chat/completions`.
|
||||
- No synthetic "collect full body then fake stream" path remains on the primary flow.
|
||||
|
||||
### Task 2.1: Refactor GitLab Executor Into Strategy Layers
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Split current executor into explicit strategies:
|
||||
- auth refresh/direct access refresh
|
||||
- gateway transport
|
||||
- GitLab REST fallback transport
|
||||
- downstream translation helpers
|
||||
- **Dependencies**: Sprint 1
|
||||
- **Acceptance Criteria**:
|
||||
- Executor no longer mixes discovery, refresh, fallback selection, and response synthesis in one path.
|
||||
- Transport choice is testable in isolation.
|
||||
- **Validation**:
|
||||
- Unit tests for strategy selection and fallback boundaries.
|
||||
|
||||
### Task 2.2: Implement Real Streaming Path
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go)
|
||||
- **Description**: Replace synthetic streaming with true upstream incremental forwarding:
|
||||
- use gateway stream if available
|
||||
- otherwise consume GitLab Code Suggestions streaming response and map chunks incrementally
|
||||
- **Dependencies**: Task 2.1
|
||||
- **Acceptance Criteria**:
|
||||
- `ExecuteStream` emits chunks before upstream completion.
|
||||
- error handling preserves status and early failure semantics.
|
||||
- **Validation**:
|
||||
- tests with chunked upstream server
|
||||
- manual curl check against `/v1/chat/completions` with `stream=true`
|
||||
|
||||
### Task 2.3: Preserve Upstream Auth And Headers Correctly
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go)
|
||||
- **Description**: Use `direct_access` connection details as first-class transport state:
|
||||
- gateway token
|
||||
- expiry
|
||||
- mandatory forwarded headers
|
||||
- model metadata
|
||||
- **Dependencies**: Task 2.1
|
||||
- **Acceptance Criteria**:
|
||||
- executor stops ignoring gateway headers/token when transport requires them
|
||||
- refresh logic never over-fetches `direct_access`
|
||||
- **Validation**:
|
||||
- tests verifying propagated headers and refresh interval behavior
|
||||
|
||||
## Sprint 3: Request/Response Semantics Parity
|
||||
**Goal**: Make GitLab Duo behave correctly under the same request shapes that current `codex` consumers send.
|
||||
|
||||
**Demo/Validation**:
|
||||
- OpenAI and Claude-compatible clients can do non-streaming and streaming conversations without losing structure.
|
||||
|
||||
### Task 3.1: Normalize Multi-Turn Message Mapping
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/translator](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/translator)
|
||||
- **Description**: Replace the current "flatten prompt into one instruction" behavior with stable multi-turn mapping:
|
||||
- preserve system context
|
||||
- preserve user/assistant ordering
|
||||
- maintain bounded context truncation
|
||||
- **Dependencies**: Sprint 2
|
||||
- **Acceptance Criteria**:
|
||||
- multi-turn requests are not collapsed into a lossy single string unless fallback mode explicitly requires it
|
||||
- truncation policy is deterministic and tested
|
||||
- **Validation**:
|
||||
- golden tests for request mapping
|
||||
|
||||
### Task 3.2: Tool Calling Compatibility Layer
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go)
|
||||
- **Description**: Decide and implement one of two paths:
|
||||
- native pass-through if GitLab gateway supports tool/function structures
|
||||
- strict downgrade path with explicit unsupported errors instead of silent field loss
|
||||
- **Dependencies**: Task 3.1
|
||||
- **Acceptance Criteria**:
|
||||
- tool-related fields are either preserved correctly or rejected explicitly
|
||||
- no silent corruption of tool names, tool calls, or tool results
|
||||
- **Validation**:
|
||||
- table-driven tests for tool payloads
|
||||
- one manual client scenario using tools
|
||||
|
||||
### Task 3.3: Token Counting And Usage Reporting Fidelity
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/usage_helpers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/usage_helpers.go)
|
||||
- **Description**: Improve token/usage reporting so GitLab models behave like first-class providers in logs and scheduling.
|
||||
- **Dependencies**: Sprint 2
|
||||
- **Acceptance Criteria**:
|
||||
- `CountTokens` uses the closest supported estimation path
|
||||
- usage logging distinguishes prompt vs completion when possible
|
||||
- **Validation**:
|
||||
- unit tests for token estimation outputs
|
||||
|
||||
## Sprint 4: Responses And Session Parity
|
||||
**Goal**: Reach codex-level support for OpenAI Responses clients and long-lived sessions where GitLab upstream permits it.
|
||||
|
||||
**Demo/Validation**:
|
||||
- `/v1/responses` works with GitLab Duo in a realistic client flow.
|
||||
- If websocket parity is not possible, the code explicitly declines it and keeps HTTP paths stable.
|
||||
|
||||
### Task 4.1: Make GitLab Compatible With `/v1/responses`
|
||||
- **Location**: [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Ensure GitLab transport can safely back the Responses API path, including compact responses if applicable.
|
||||
- **Dependencies**: Sprint 3
|
||||
- **Acceptance Criteria**:
|
||||
- GitLab Duo can be selected behind `/v1/responses`
|
||||
- response IDs and follow-up semantics are defined
|
||||
- **Validation**:
|
||||
- handler tests analogous to codex/openai responses tests
|
||||
|
||||
### Task 4.2: Evaluate Downstream Websocket Parity
|
||||
- **Location**: [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Decide whether GitLab Duo can support downstream websocket sessions like codex:
|
||||
- if yes, add session-aware execution path
|
||||
- if no, mark GitLab auth as websocket-ineligible and keep HTTP routes first-class
|
||||
- **Dependencies**: Task 4.1
|
||||
- **Acceptance Criteria**:
|
||||
- websocket behavior is explicit, not accidental
|
||||
- no route claims websocket support when the upstream cannot honor it
|
||||
- **Validation**:
|
||||
- websocket handler tests or explicit capability tests
|
||||
|
||||
### Task 4.3: Add Session Cleanup And Failure Recovery Semantics
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/cliproxy/auth/conductor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/auth/conductor.go)
|
||||
- **Description**: Add codex-like session cleanup, retry boundaries, and model suspension/resume behavior for GitLab failures and quota events.
|
||||
- **Dependencies**: Sprint 2
|
||||
- **Acceptance Criteria**:
|
||||
- auth/model cooldown behavior is predictable on GitLab 4xx/5xx/quota responses
|
||||
- executor cleans up per-session resources if any are introduced
|
||||
- **Validation**:
|
||||
- tests for quota and retry behavior
|
||||
|
||||
## Sprint 5: Client UX, Model UX, And Manual E2E
|
||||
**Goal**: Make GitLab Duo feel like a normal built-in provider to operators and downstream clients.
|
||||
|
||||
**Demo/Validation**:
|
||||
- A documented setup exists for "login once, point Claude Code at CLIProxyAPI, use GitLab Duo-backed model".
|
||||
|
||||
### Task 5.1: Model Alias And Provider UX Cleanup
|
||||
- **Location**: [sdk/cliproxy/service.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/service.go), [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
|
||||
- **Description**: Normalize what users see:
|
||||
- stable alias such as `gitlab-duo`
|
||||
- discovered upstream model names
|
||||
- optional prefix behavior
|
||||
- account labels that clearly distinguish OAuth vs PAT
|
||||
- **Dependencies**: Sprint 3
|
||||
- **Acceptance Criteria**:
|
||||
- users can select a stable GitLab alias even when upstream model changes
|
||||
- dynamic model discovery does not cause confusing model churn
|
||||
- **Validation**:
|
||||
- registry tests and manual `/v1/models` inspection
|
||||
|
||||
### Task 5.2: Add Real End-To-End Acceptance Tests
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go), [sdk/api/handlers/openai](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai)
|
||||
- **Description**: Add higher-level tests covering the actual proxy surfaces:
|
||||
- OpenAI `chat/completions`
|
||||
- OpenAI `responses`
|
||||
- Claude-compatible request path if GitLab is routed there
|
||||
- **Dependencies**: Sprint 4
|
||||
- **Acceptance Criteria**:
|
||||
- tests fail if streaming regresses into synthetic buffering again
|
||||
- tests cover at least one tool-related request and one multi-turn request
|
||||
- **Validation**:
|
||||
- `go test ./...`
|
||||
|
||||
### Task 5.3: Publish Operator Documentation
|
||||
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
|
||||
- **Description**: Document:
|
||||
- OAuth setup requirements
|
||||
- PAT requirements
|
||||
- current capability matrix
|
||||
- known limitations if websocket/tool parity is partial
|
||||
- **Dependencies**: Sprint 5.1
|
||||
- **Acceptance Criteria**:
|
||||
- setup instructions are enough for a new user to reproduce the GitLab Duo flow
|
||||
- limitations are explicit
|
||||
- **Validation**:
|
||||
- dry-run docs review from a clean environment
|
||||
|
||||
## Testing Strategy
|
||||
- Keep `go test ./...` green after every committable task.
|
||||
- Add table-driven tests first for request mapping, refresh behavior, and dynamic model registration.
|
||||
- Add transport tests with `httptest.Server` for:
|
||||
- real chunked streaming
|
||||
- header propagation from `direct_access`
|
||||
- upstream fallback rules
|
||||
- Add at least one manual acceptance checklist:
|
||||
- login via OAuth
|
||||
- login via PAT
|
||||
- list models
|
||||
- run one streaming prompt via OpenAI route
|
||||
- run one prompt from the target downstream client
|
||||
|
||||
## Potential Risks & Gotchas
|
||||
- GitLab public docs expose `direct_access`, but do not fully document every possible AI gateway path. We should isolate any empirically discovered gateway assumptions behind one transport layer and feature flags.
|
||||
- `chat/completions` availability differs by GitLab offering and version. The executor must not assume it always exists.
|
||||
- Code Suggestions is completion-oriented; lossy mapping from rich chat/tool payloads will make GitLab Duo feel worse than codex unless explicitly handled.
|
||||
- Synthetic streaming is not good enough for codex parity and will cause regressions in interactive clients.
|
||||
- Dynamic model discovery can create unstable UX if the stable alias and discovered model IDs are not separated cleanly.
|
||||
- PAT auth may validate successfully while still lacking effective Duo permissions. Error reporting must surface this explicitly.
|
||||
|
||||
## Rollback Plan
|
||||
- Keep the current basic GitLab executor behind a fallback mode until the new transport path is stable.
|
||||
- If parity work destabilizes existing providers, revert only GitLab-specific executor changes and leave auth support intact.
|
||||
- Preserve the stable `gitlab-duo` alias so rollback does not break client configuration.
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -14,13 +13,12 @@ import (
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const defaultAPICallTimeout = 60 * time.Second
|
||||
@@ -725,47 +723,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
}
|
||||
|
||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
proxyStr = strings.TrimSpace(proxyStr)
|
||||
if proxyStr == "" {
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||
if errBuild != nil {
|
||||
log.WithError(errBuild).Debug("build proxy transport failed")
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyURL, errParse := url.Parse(proxyStr)
|
||||
if errParse != nil {
|
||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
||||
return nil
|
||||
}
|
||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
||||
log.Debug("proxy URL missing scheme/host")
|
||||
return nil
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
||||
return nil
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
return nil
|
||||
return transport
|
||||
}
|
||||
|
||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||
|
||||
@@ -1,173 +1,58 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, a := range s.items {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
_ = ctx
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth.Clone()
|
||||
s.mu.Unlock()
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
delete(s.items, id)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("unexpected content-type: %s", ct)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
values, err := url.ParseQuery(string(bodyBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("parse form: %v", err)
|
||||
}
|
||||
if values.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
||||
}
|
||||
if values.Get("refresh_token") != "rt" {
|
||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
||||
}
|
||||
if values.Get("client_id") != antigravityOAuthClientID {
|
||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
||||
}
|
||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
||||
t.Fatalf("unexpected client_secret")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-token",
|
||||
"refresh_token": "rt2",
|
||||
"expires_in": int64(3600),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-test.json",
|
||||
FileName: "antigravity-test.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "rt",
|
||||
"expires_in": int64(3600),
|
||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||
},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
h := &Handler{authManager: manager}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
|
||||
httpTransport, ok := transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth in manager after update")
|
||||
}
|
||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
||||
if httpTransport.Proxy != nil {
|
||||
t.Fatal("expected direct transport to disable proxy function")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-valid.json",
|
||||
FileName: "antigravity-valid.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "ok-token",
|
||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||
},
|
||||
}
|
||||
h := &Handler{}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
|
||||
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
|
||||
httpTransport, ok := transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||
}
|
||||
if token != "ok-token" {
|
||||
t.Fatalf("expected existing token, got %q", token)
|
||||
|
||||
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
if errRequest != nil {
|
||||
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
||||
|
||||
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||
if errProxy != nil {
|
||||
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||
}
|
||||
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
|
||||
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
@@ -54,6 +55,8 @@ const (
|
||||
codexCallbackPort = 1455
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
gitLabLoginModeOAuth = "oauth"
|
||||
gitLabLoginModePAT = "pat"
|
||||
)
|
||||
|
||||
type callbackForwarder struct {
|
||||
@@ -999,6 +1002,165 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func gitLabBaseURLFromRequest(c *gin.Context) string {
|
||||
if c != nil {
|
||||
if raw := strings.TrimSpace(c.Query("base_url")); raw != "" {
|
||||
return gitlabauth.NormalizeBaseURL(raw)
|
||||
}
|
||||
}
|
||||
if raw := strings.TrimSpace(os.Getenv("GITLAB_BASE_URL")); raw != "" {
|
||||
return gitlabauth.NormalizeBaseURL(raw)
|
||||
}
|
||||
return gitlabauth.DefaultBaseURL
|
||||
}
|
||||
|
||||
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
|
||||
metadata := map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": strings.TrimSpace(mode),
|
||||
"base_url": gitlabauth.NormalizeBaseURL(baseURL),
|
||||
"last_refresh": time.Now().UTC().Format(time.RFC3339),
|
||||
"refresh_interval_seconds": 240,
|
||||
}
|
||||
if tokenResp != nil {
|
||||
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
|
||||
metadata["refresh_token"] = refreshToken
|
||||
}
|
||||
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
|
||||
metadata["token_type"] = tokenType
|
||||
}
|
||||
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
|
||||
metadata["scope"] = scope
|
||||
}
|
||||
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
|
||||
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
mergeGitLabDirectAccessMetadata(metadata, direct)
|
||||
return metadata
|
||||
}
|
||||
|
||||
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
|
||||
if metadata == nil || direct == nil {
|
||||
return
|
||||
}
|
||||
if base := strings.TrimSpace(direct.BaseURL); base != "" {
|
||||
metadata["duo_gateway_base_url"] = base
|
||||
}
|
||||
if token := strings.TrimSpace(direct.Token); token != "" {
|
||||
metadata["duo_gateway_token"] = token
|
||||
}
|
||||
if direct.ExpiresAt > 0 {
|
||||
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
|
||||
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
|
||||
now := time.Now().UTC()
|
||||
if ttl := expiry.Sub(now); ttl > 0 {
|
||||
interval := int(ttl.Seconds()) / 2
|
||||
switch {
|
||||
case interval < 60:
|
||||
interval = 60
|
||||
case interval > 240:
|
||||
interval = 240
|
||||
}
|
||||
metadata["refresh_interval_seconds"] = interval
|
||||
}
|
||||
}
|
||||
if len(direct.Headers) > 0 {
|
||||
headers := make(map[string]string, len(direct.Headers))
|
||||
for key, value := range direct.Headers {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
headers[key] = value
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
metadata["duo_gateway_headers"] = headers
|
||||
}
|
||||
}
|
||||
if direct.ModelDetails != nil {
|
||||
modelDetails := map[string]any{}
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
modelDetails["model_provider"] = provider
|
||||
metadata["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
modelDetails["model_name"] = model
|
||||
metadata["model_name"] = model
|
||||
}
|
||||
if len(modelDetails) > 0 {
|
||||
metadata["model_details"] = modelDetails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func primaryGitLabEmail(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
if value := strings.TrimSpace(user.Email); value != "" {
|
||||
return value
|
||||
}
|
||||
return strings.TrimSpace(user.PublicEmail)
|
||||
}
|
||||
|
||||
func gitLabAccountIdentifier(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return "user"
|
||||
}
|
||||
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
|
||||
func sanitizeGitLabFileName(value string) string {
|
||||
value = strings.TrimSpace(strings.ToLower(value))
|
||||
if value == "" {
|
||||
return "user"
|
||||
}
|
||||
var builder strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r >= '0' && r <= '9':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
default:
|
||||
if !lastDash {
|
||||
builder.WriteRune('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
}
|
||||
result := strings.Trim(builder.String(), "-")
|
||||
if result == "" {
|
||||
return "user"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskGitLabToken(token string) string {
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if len(trimmed) <= 8 {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
@@ -1312,12 +1474,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1326,7 +1488,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
@@ -1337,19 +1499,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1362,13 +1524,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1549,6 +1711,263 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestGitLabToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing GitLab Duo authentication...")
|
||||
|
||||
baseURL := gitLabBaseURLFromRequest(c)
|
||||
clientID := strings.TrimSpace(c.Query("client_id"))
|
||||
clientSecret := strings.TrimSpace(c.Query("client_secret"))
|
||||
if clientID == "" {
|
||||
clientID = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_ID"))
|
||||
}
|
||||
if clientSecret == "" {
|
||||
clientSecret = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_SECRET"))
|
||||
}
|
||||
if clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "gitlab client_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
pkceCodes, err := gitlabauth.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate GitLab PKCE codes: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||
return
|
||||
}
|
||||
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate GitLab state parameter: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := gitlabauth.RedirectURL(gitlabauth.DefaultCallbackPort)
|
||||
authClient := gitlabauth.NewAuthClient(h.cfg)
|
||||
authURL, err := authClient.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate GitLab authorization URL: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "gitlab")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/gitlab/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute gitlab callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(gitlabauth.DefaultCallbackPort, "gitlab", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start gitlab callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarderInstance(gitlabauth.DefaultCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gitlab-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var code string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "gitlab") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("gitlab oauth flow timed out")
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errRead := os.ReadFile(waitFile); errRead == nil {
|
||||
var payload map[string]string
|
||||
_ = json.Unmarshal(data, &payload)
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
SetOAuthSessionError(state, errStr)
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != state {
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
code = strings.TrimSpace(payload["code"])
|
||||
if code == "" {
|
||||
SetOAuthSessionError(state, "Authorization code missing")
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
tokenResp, errExchange := authClient.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, code, pkceCodes.CodeVerifier)
|
||||
if errExchange != nil {
|
||||
log.Errorf("Failed to exchange GitLab authorization code: %v", errExchange)
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
|
||||
user, errUser := authClient.GetCurrentUser(ctx, baseURL, tokenResp.AccessToken)
|
||||
if errUser != nil {
|
||||
log.Errorf("Failed to fetch GitLab user profile: %v", errUser)
|
||||
SetOAuthSessionError(state, "Failed to fetch account profile")
|
||||
return
|
||||
}
|
||||
|
||||
direct, errDirect := authClient.FetchDirectAccess(ctx, baseURL, tokenResp.AccessToken)
|
||||
if errDirect != nil {
|
||||
log.Errorf("Failed to fetch GitLab direct access metadata: %v", errDirect)
|
||||
SetOAuthSessionError(state, "Failed to fetch GitLab Duo access")
|
||||
return
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata["oauth_client_id"] = clientID
|
||||
if clientSecret != "" {
|
||||
metadata["oauth_client_secret"] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "gitlab",
|
||||
FileName: fileName,
|
||||
Label: identifier,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save GitLab auth record: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("GitLab Duo authentication successful. Token saved to %s\n", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("gitlab")
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestGitLabPATToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
var payload struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
PersonalAccessToken string `json:"personal_access_token"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := gitlabauth.NormalizeBaseURL(strings.TrimSpace(payload.BaseURL))
|
||||
if baseURL == "" {
|
||||
baseURL = gitLabBaseURLFromRequest(nil)
|
||||
}
|
||||
pat := strings.TrimSpace(payload.PersonalAccessToken)
|
||||
if pat == "" {
|
||||
pat = strings.TrimSpace(payload.Token)
|
||||
}
|
||||
if pat == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "personal_access_token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
authClient := gitlabauth.NewAuthClient(h.cfg)
|
||||
|
||||
user, err := authClient.GetCurrentUser(ctx, baseURL, pat)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
|
||||
return
|
||||
}
|
||||
patSelf, err := authClient.GetPersonalAccessTokenSelf(ctx, baseURL, pat)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
|
||||
return
|
||||
}
|
||||
direct, err := authClient.FetchDirectAccess(ctx, baseURL, pat)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
|
||||
metadata["auth_kind"] = "personal_access_token"
|
||||
metadata["personal_access_token"] = pat
|
||||
metadata["token_preview"] = maskGitLabToken(pat)
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
if patSelf != nil {
|
||||
if name := strings.TrimSpace(patSelf.Name); name != "" {
|
||||
metadata["pat_name"] = name
|
||||
}
|
||||
if len(patSelf.Scopes) > 0 {
|
||||
metadata["pat_scopes"] = append([]string(nil), patSelf.Scopes...)
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "gitlab",
|
||||
FileName: fileName,
|
||||
Label: identifier + " (PAT)",
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
savedPath, err := h.saveTokenRecord(ctx, record)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": "ok",
|
||||
"saved_path": savedPath,
|
||||
"username": strings.TrimSpace(user.Username),
|
||||
"email": primaryGitLabEmail(user),
|
||||
"token_label": identifier,
|
||||
}
|
||||
if direct != nil && direct.ModelDetails != nil {
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
response["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
response["model_name"] = model
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("GitLab Duo PAT authentication successful. Token saved to %s\n", savedPath)
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
|
||||
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/api/v4/user":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 42,
|
||||
"username": "gitlab-user",
|
||||
"name": "GitLab User",
|
||||
"email": "gitlab@example.com",
|
||||
})
|
||||
case "/api/v4/personal_access_tokens/self":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 7,
|
||||
"name": "management-center",
|
||||
"scopes": []string{"api", "read_user"},
|
||||
"user_id": 42,
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1893456000,
|
||||
"headers": map[string]string{
|
||||
"X-Gitlab-Realm": "saas",
|
||||
},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
|
||||
h.tokenStore = store
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.RequestGitLabPATToken(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if got := resp["status"]; got != "ok" {
|
||||
t.Fatalf("status = %#v, want ok", got)
|
||||
}
|
||||
if got := resp["model_provider"]; got != "anthropic" {
|
||||
t.Fatalf("model_provider = %#v, want anthropic", got)
|
||||
}
|
||||
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
|
||||
}
|
||||
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if len(store.items) != 1 {
|
||||
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
|
||||
}
|
||||
var saved *coreauth.Auth
|
||||
for _, item := range store.items {
|
||||
saved = item
|
||||
}
|
||||
if saved == nil {
|
||||
t.Fatal("expected saved auth record")
|
||||
}
|
||||
if saved.Provider != "gitlab" {
|
||||
t.Fatalf("provider = %q, want gitlab", saved.Provider)
|
||||
}
|
||||
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
|
||||
}
|
||||
if got := saved.Metadata["model_provider"]; got != "anthropic" {
|
||||
t.Fatalf("saved model_provider = %#v, want anthropic", got)
|
||||
}
|
||||
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
|
||||
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
state := "gitlab-state-123"
|
||||
RegisterOAuthSession(state, "gitlab")
|
||||
t.Cleanup(func() { CompleteOAuthSession(state) })
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.PostOAuthCallback(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read callback file: %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
t.Fatalf("decode callback payload: %v", err)
|
||||
}
|
||||
if got := payload["code"]; got != "test-code" {
|
||||
t.Fatalf("callback code = %q, want test-code", got)
|
||||
}
|
||||
if got := payload["state"]; got != state {
|
||||
t.Fatalf("callback state = %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
|
||||
provider, err := NormalizeOAuthProvider("gitlab")
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
|
||||
}
|
||||
if provider != "gitlab" {
|
||||
t.Fatalf("provider = %q, want gitlab", provider)
|
||||
}
|
||||
}
|
||||
@@ -228,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gitlab":
|
||||
return "gitlab", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
|
||||
49
internal/api/handlers/management/test_store_test.go
Normal file
49
internal/api/handlers/management/test_store_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, item := range s.items {
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.items, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) SetBaseDir(string) {}
|
||||
@@ -403,6 +403,20 @@ func (s *Server) setupRoutes() {
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/google/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
@@ -658,6 +672,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
|
||||
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
|
||||
@@ -4,12 +4,12 @@ package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/proxy"
|
||||
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
|
||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||
var dialer proxy.Dialer = proxy.Direct
|
||||
if cfg != nil && cfg.ProxyURL != "" {
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
||||
} else {
|
||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
||||
} else {
|
||||
dialer = pDialer
|
||||
}
|
||||
if cfg != nil {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
@@ -20,9 +18,9 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
var transport *http.Transport
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
// Handle SOCKS5 proxy.
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
auth := &proxy.Auth{User: username, Password: password}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||
}
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
// Handle HTTP/HTTPS proxy.
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
if transport != nil {
|
||||
proxyClient := &http.Client{Transport: transport}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||
}
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
} else if transport != nil {
|
||||
proxyClient := &http.Client{Transport: transport}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Configure the OAuth2 client.
|
||||
conf := &oauth2.Config{
|
||||
ClientID: ClientID,
|
||||
|
||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBaseURL = "https://gitlab.com"
|
||||
DefaultCallbackPort = 17171
|
||||
defaultOAuthScope = "api read_user"
|
||||
)
|
||||
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type OAuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
type OAuthServer struct {
|
||||
server *http.Server
|
||||
port int
|
||||
resultChan chan *OAuthResult
|
||||
errorChan chan error
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
PublicEmail string `json:"public_email"`
|
||||
}
|
||||
|
||||
type PersonalAccessTokenSelf struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
type ModelDetails struct {
|
||||
ModelProvider string `json:"model_provider"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
type DirectAccessResponse struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
ModelDetails *ModelDetails `json:"model_details,omitempty"`
|
||||
}
|
||||
|
||||
type DiscoveredModel struct {
|
||||
ModelProvider string
|
||||
ModelName string
|
||||
}
|
||||
|
||||
type AuthClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewAuthClient(cfg *config.Config) *AuthClient {
|
||||
client := &http.Client{}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &AuthClient{httpClient: client}
|
||||
}
|
||||
|
||||
func NormalizeBaseURL(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return DefaultBaseURL
|
||||
}
|
||||
if !strings.Contains(value, "://") {
|
||||
value = "https://" + value
|
||||
}
|
||||
value = strings.TrimRight(value, "/")
|
||||
return value
|
||||
}
|
||||
|
||||
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
||||
if token == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
||||
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
||||
}
|
||||
if token.ExpiresIn > 0 {
|
||||
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
verifierBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
||||
}
|
||||
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return &PKCECodes{
|
||||
CodeVerifier: verifier,
|
||||
CodeChallenge: challenge,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
resultChan: make(chan *OAuthResult, 1),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("gitlab oauth server already running")
|
||||
}
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
s.running = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.errorChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
s.running = false
|
||||
s.server = nil
|
||||
}()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case result := <-s.resultChan:
|
||||
return result, nil
|
||||
case err := <-s.errorChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
query := r.URL.Query()
|
||||
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||
s.sendResult(&OAuthResult{Error: errParam})
|
||||
http.Error(w, errParam, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
if code == "" || state == "" {
|
||||
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
||||
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
||||
}
|
||||
|
||||
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||
select {
|
||||
case s.resultChan <- result:
|
||||
default:
|
||||
log.Debug("gitlab oauth result channel full, dropping callback result")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func RedirectURL(port int) string {
|
||||
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
|
||||
}
|
||||
|
||||
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
||||
if pkce == nil {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
||||
}
|
||||
if strings.TrimSpace(clientID) == "" {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
||||
}
|
||||
baseURL = NormalizeBaseURL(baseURL)
|
||||
params := url.Values{
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"scope": {defaultOAuthScope},
|
||||
"state": {strings.TrimSpace(state)},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
}
|
||||
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"code": {strings.TrimSpace(code)},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||
}
|
||||
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
||||
form.Set("client_id", clientID)
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var token TokenResponse
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var pat PersonalAccessTokenSelf
|
||||
if err := json.Unmarshal(body, &pat); err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
||||
}
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var direct DirectAccessResponse
|
||||
if err := json.Unmarshal(body, &direct); err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
||||
}
|
||||
if direct.Headers == nil {
|
||||
direct.Headers = make(map[string]string)
|
||||
}
|
||||
return &direct, nil
|
||||
}
|
||||
|
||||
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
models := make([]DiscoveredModel, 0, 4)
|
||||
seen := make(map[string]struct{})
|
||||
appendModel := func(provider, name string) {
|
||||
provider = strings.TrimSpace(provider)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
models = append(models, DiscoveredModel{
|
||||
ModelProvider: provider,
|
||||
ModelName: name,
|
||||
})
|
||||
}
|
||||
|
||||
if raw, ok := metadata["model_details"]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||
|
||||
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||
if raw, ok := metadata[key]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||
if nested, ok := typed["models"]; ok {
|
||||
appendDiscoveredModels(nested, appendModel)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range typed {
|
||||
appendDiscoveredModels(item, appendModel)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range typed {
|
||||
appendModel("", item)
|
||||
}
|
||||
case string:
|
||||
appendModel("", typed)
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(raw any) string {
|
||||
switch typed := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(typed.String())
|
||||
case json.Number:
|
||||
return typed.String()
|
||||
case int:
|
||||
return strconv.Itoa(typed)
|
||||
case int64:
|
||||
return strconv.FormatInt(typed, 10)
|
||||
case float64:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
138
internal/auth/gitlab/gitlab_test.go
Normal file
138
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
|
||||
client := NewAuthClient(nil)
|
||||
pkce, err := GeneratePKCECodes()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePKCECodes() error = %v", err)
|
||||
}
|
||||
|
||||
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL() error = %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse(authURL) error = %v", err)
|
||||
}
|
||||
if got := parsed.Path; got != "/oauth/authorize" {
|
||||
t.Fatalf("expected /oauth/authorize path, got %q", got)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("client_id"); got != "client-id" {
|
||||
t.Fatalf("expected client_id, got %q", got)
|
||||
}
|
||||
if got := query.Get("scope"); got != defaultOAuthScope {
|
||||
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
|
||||
}
|
||||
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||
t.Fatalf("expected PKCE method S256, got %q", got)
|
||||
}
|
||||
if got := query.Get("code_challenge"); got == "" {
|
||||
t.Fatal("expected non-empty code_challenge")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm() error = %v", err)
|
||||
}
|
||||
if got := r.Form.Get("grant_type"); got != "authorization_code" {
|
||||
t.Fatalf("expected authorization_code grant, got %q", got)
|
||||
}
|
||||
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
|
||||
t.Fatalf("expected code_verifier, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"token_type": "Bearer",
|
||||
"scope": "api read_user",
|
||||
"created_at": 1710000000,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewAuthClient(nil)
|
||||
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
|
||||
}
|
||||
if token.AccessToken != "oauth-access" {
|
||||
t.Fatalf("expected access token, got %q", token.AccessToken)
|
||||
}
|
||||
if token.RefreshToken != "oauth-refresh" {
|
||||
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDiscoveredModels(t *testing.T) {
|
||||
models := ExtractDiscoveredModels(map[string]any{
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
"supported_models": []any{
|
||||
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||
"claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("expected 2 unique models, got %d", len(models))
|
||||
}
|
||||
if models[0].ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("unexpected first model %q", models[0].ModelName)
|
||||
}
|
||||
if models[1].ModelName != "gpt-4.1" {
|
||||
t.Fatalf("unexpected second model %q", models[1].ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
|
||||
t.Fatalf("expected bearer token, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{
|
||||
"X-Gitlab-Realm": "saas",
|
||||
},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewAuthClient(nil)
|
||||
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchDirectAccess() error = %v", err)
|
||||
}
|
||||
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
69
internal/cmd/gitlab_login.go
Normal file
69
internal/cmd/gitlab_login.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "oauth",
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("GitLab Duo authentication successful!")
|
||||
}
|
||||
|
||||
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "pat",
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("GitLab Duo PAT authentication successful!")
|
||||
}
|
||||
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
configYAML := []byte(`
|
||||
codex-header-defaults:
|
||||
user-agent: " my-codex-client/1.0 "
|
||||
beta-features: " feature-a,feature-b "
|
||||
`)
|
||||
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfigOptional(configPath, false)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||
}
|
||||
|
||||
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
|
||||
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
|
||||
}
|
||||
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
|
||||
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
|
||||
}
|
||||
}
|
||||
@@ -101,6 +101,10 @@ type Config struct {
|
||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||
|
||||
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
|
||||
// These are used only when the client does not send its own headers.
|
||||
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
|
||||
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
@@ -150,6 +154,14 @@ type ClaudeHeaderDefaults struct {
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||
// model requests for OAuth/file-backed auth when the client omits them.
|
||||
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
|
||||
type CodexHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -679,6 +691,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize Codex keys: drop entries without base-url
|
||||
cfg.SanitizeCodexKeys()
|
||||
|
||||
// Sanitize Codex header defaults.
|
||||
cfg.SanitizeCodexHeaderDefaults()
|
||||
|
||||
// Sanitize Claude key headers
|
||||
cfg.SanitizeClaudeKeys()
|
||||
|
||||
@@ -771,6 +786,16 @@ func payloadRawString(value any) ([]byte, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
|
||||
// configured Codex header fallback values.
|
||||
func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
|
||||
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
|
||||
@@ -1,12 +1,105 @@
|
||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||
// Static model metadata is stored in model_definitions_static_data.go.
|
||||
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||
type staticModelsJSON struct {
|
||||
Claude []*ModelInfo `json:"claude"`
|
||||
Gemini []*ModelInfo `json:"gemini"`
|
||||
Vertex []*ModelInfo `json:"vertex"`
|
||||
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||
AIStudio []*ModelInfo `json:"aistudio"`
|
||||
CodexFree []*ModelInfo `json:"codex-free"`
|
||||
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||
Qwen []*ModelInfo `json:"qwen"`
|
||||
IFlow []*ModelInfo `json:"iflow"`
|
||||
Kimi []*ModelInfo `json:"kimi"`
|
||||
Antigravity []*ModelInfo `json:"antigravity"`
|
||||
}
|
||||
|
||||
// GetClaudeModels returns the standard Claude model definitions.
|
||||
func GetClaudeModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Claude)
|
||||
}
|
||||
|
||||
// GetGeminiModels returns the standard Gemini model definitions.
|
||||
func GetGeminiModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Gemini)
|
||||
}
|
||||
|
||||
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||
func GetGeminiVertexModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Vertex)
|
||||
}
|
||||
|
||||
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||
func GetGeminiCLIModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().GeminiCLI)
|
||||
}
|
||||
|
||||
// GetAIStudioModels returns model definitions for AI Studio.
|
||||
func GetAIStudioModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().AIStudio)
|
||||
}
|
||||
|
||||
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||
func GetCodexFreeModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexFree)
|
||||
}
|
||||
|
||||
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||
func GetCodexTeamModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexTeam)
|
||||
}
|
||||
|
||||
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||
func GetCodexPlusModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexPlus)
|
||||
}
|
||||
|
||||
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||
func GetCodexProModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexPro)
|
||||
}
|
||||
|
||||
// GetQwenModels returns the standard Qwen model definitions.
|
||||
func GetQwenModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Qwen)
|
||||
}
|
||||
|
||||
// GetIFlowModels returns the standard iFlow model definitions.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().IFlow)
|
||||
}
|
||||
|
||||
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||
func GetKimiModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Kimi)
|
||||
}
|
||||
|
||||
// GetAntigravityModels returns the standard Antigravity model definitions.
|
||||
func GetAntigravityModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Antigravity)
|
||||
}
|
||||
|
||||
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*ModelInfo, len(models))
|
||||
for i, m := range models {
|
||||
out[i] = cloneModelInfo(m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||
// It returns nil when the channel is unknown.
|
||||
//
|
||||
@@ -20,7 +113,6 @@ import (
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - kiro
|
||||
// - kilo
|
||||
// - github-copilot
|
||||
// - amazonq
|
||||
@@ -39,7 +131,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
case "aistudio":
|
||||
return GetAIStudioModels()
|
||||
case "codex":
|
||||
return GetOpenAIModels()
|
||||
return GetCodexProModels()
|
||||
case "qwen":
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
@@ -55,28 +147,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
case "amazonq":
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
cfg := GetAntigravityModelConfig()
|
||||
if len(cfg) == 0 {
|
||||
return nil
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(cfg))
|
||||
for modelID, entry := range cfg {
|
||||
if modelID == "" || entry == nil {
|
||||
continue
|
||||
}
|
||||
models = append(models, &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
OwnedBy: "antigravity",
|
||||
Type: "antigravity",
|
||||
Thinking: entry.Thinking,
|
||||
MaxCompletionTokens: entry.MaxCompletionTokens,
|
||||
})
|
||||
}
|
||||
sort.Slice(models, func(i, j int) bool {
|
||||
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
|
||||
})
|
||||
return models
|
||||
return GetAntigravityModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -89,16 +160,18 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := getModels()
|
||||
allModels := [][]*ModelInfo{
|
||||
GetClaudeModels(),
|
||||
GetGeminiModels(),
|
||||
GetGeminiVertexModels(),
|
||||
GetGeminiCLIModels(),
|
||||
GetAIStudioModels(),
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
data.Claude,
|
||||
data.Gemini,
|
||||
data.Vertex,
|
||||
data.GeminiCLI,
|
||||
data.AIStudio,
|
||||
data.CodexPro,
|
||||
data.Qwen,
|
||||
data.IFlow,
|
||||
data.Kimi,
|
||||
data.Antigravity,
|
||||
GetGitHubCopilotModels(),
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
@@ -107,20 +180,11 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
if m != nil && m.ID == modelID {
|
||||
return m
|
||||
return cloneModelInfo(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check Antigravity static config
|
||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||
return &ModelInfo{
|
||||
ID: modelID,
|
||||
Thinking: cfg.Thinking,
|
||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -64,6 +64,11 @@ type ModelInfo struct {
|
||||
UserDefined bool `json:"-"`
|
||||
}
|
||||
|
||||
type availableModelsCacheEntry struct {
|
||||
models []map[string]any
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||
// Values are interpreted in provider-native token units.
|
||||
type ThinkingSupport struct {
|
||||
@@ -118,6 +123,8 @@ type ModelRegistry struct {
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
mutex *sync.RWMutex
|
||||
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
|
||||
availableModelsCache map[string]availableModelsCacheEntry
|
||||
// hook is an optional callback sink for model registration changes
|
||||
hook ModelRegistryHook
|
||||
}
|
||||
@@ -130,15 +137,28 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
availableModelsCache: make(map[string]availableModelsCacheEntry),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
|
||||
if r.availableModelsCache == nil {
|
||||
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
|
||||
if len(r.availableModelsCache) == 0 {
|
||||
return
|
||||
}
|
||||
clear(r.availableModelsCache)
|
||||
}
|
||||
|
||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
@@ -153,9 +173,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
}
|
||||
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||
return info
|
||||
return cloneModelInfo(info)
|
||||
}
|
||||
return LookupStaticModelInfo(modelID)
|
||||
return cloneModelInfo(LookupStaticModelInfo(modelID))
|
||||
}
|
||||
|
||||
// SetHook sets an optional hook for observing model registration changes.
|
||||
@@ -169,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
||||
}
|
||||
|
||||
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||
const modelQuotaExceededWindow = 5 * time.Minute
|
||||
|
||||
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||
hook := r.hook
|
||||
@@ -213,6 +234,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
||||
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
provider := strings.ToLower(clientProvider)
|
||||
uniqueModelIDs := make([]string, 0, len(models))
|
||||
@@ -238,6 +260,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
}
|
||||
@@ -265,6 +288,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
} else {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||
misc.LogCredentialSeparator()
|
||||
@@ -367,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
reg.LastUpdated = now
|
||||
// Re-registering an existing client/model binding starts a fresh registry
|
||||
// snapshot for that binding. Cooldown and suspension are transient
|
||||
// scheduling state and must not survive this reconciliation step.
|
||||
if reg.QuotaExceededClients != nil {
|
||||
delete(reg.QuotaExceededClients, clientID)
|
||||
}
|
||||
@@ -408,6 +435,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||
@@ -511,6 +539,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
copyThinking := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
copyModel.Thinking = ©Thinking
|
||||
}
|
||||
return ©Model
|
||||
}
|
||||
|
||||
@@ -540,6 +575,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.unregisterClientInternal(clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
}
|
||||
|
||||
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
||||
@@ -606,9 +642,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -620,9 +659,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -638,6 +679,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil {
|
||||
@@ -651,6 +693,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
}
|
||||
registration.SuspendedClients[clientID] = reason
|
||||
registration.LastUpdated = time.Now()
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
if reason != "" {
|
||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||
} else {
|
||||
@@ -668,6 +711,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||
@@ -678,6 +722,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
}
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
registration.LastUpdated = time.Now()
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||
}
|
||||
|
||||
@@ -713,22 +758,51 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
|
||||
// Returns:
|
||||
// - []map[string]any: List of available models in the requested format
|
||||
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
now := time.Now()
|
||||
|
||||
models := make([]map[string]any, 0)
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
r.mutex.RLock()
|
||||
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||
models := cloneModelMaps(cache.models)
|
||||
r.mutex.RUnlock()
|
||||
return models
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||
return cloneModelMaps(cache.models)
|
||||
}
|
||||
|
||||
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
|
||||
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
|
||||
models: cloneModelMaps(models),
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||
models := make([]map[string]any, 0, len(r.models))
|
||||
var expiresAt time.Time
|
||||
|
||||
for _, registration := range r.models {
|
||||
// Check if model has any non-quota-exceeded clients
|
||||
availableClients := registration.Count
|
||||
now := time.Now()
|
||||
|
||||
// Count clients that have exceeded quota but haven't recovered yet
|
||||
expiredClients := 0
|
||||
for _, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime == nil {
|
||||
continue
|
||||
}
|
||||
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
|
||||
if now.Before(recoveryAt) {
|
||||
expiredClients++
|
||||
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||
expiresAt = recoveryAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -749,7 +823,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
// Include models that have available clients, or those solely cooling down.
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
model := r.convertModelToMap(registration.Info, handlerType)
|
||||
if model != nil {
|
||||
@@ -758,7 +831,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
return models, expiresAt
|
||||
}
|
||||
|
||||
func cloneModelMaps(models []map[string]any) []map[string]any {
|
||||
cloned := make([]map[string]any, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
cloned = append(cloned, nil)
|
||||
continue
|
||||
}
|
||||
copyModel := make(map[string]any, len(model))
|
||||
for key, value := range model {
|
||||
copyModel[key] = cloneModelMapValue(value)
|
||||
}
|
||||
cloned = append(cloned, copyModel)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneModelMapValue(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
copyMap := make(map[string]any, len(typed))
|
||||
for key, entry := range typed {
|
||||
copyMap[key] = cloneModelMapValue(entry)
|
||||
}
|
||||
return copyMap
|
||||
case []any:
|
||||
copySlice := make([]any, len(typed))
|
||||
for i, entry := range typed {
|
||||
copySlice[i] = cloneModelMapValue(entry)
|
||||
}
|
||||
return copySlice
|
||||
case []string:
|
||||
return append([]string(nil), typed...)
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||
@@ -822,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
return nil
|
||||
}
|
||||
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
now := time.Now()
|
||||
result := make([]*ModelInfo, 0, len(providerModels))
|
||||
|
||||
@@ -844,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
@@ -874,11 +983,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
if entry.info != nil {
|
||||
result = append(result, entry.info)
|
||||
result = append(result, cloneModelInfo(entry.info))
|
||||
continue
|
||||
}
|
||||
if ok && registration != nil && registration.Info != nil {
|
||||
result = append(result, registration.Info)
|
||||
result = append(result, cloneModelInfo(registration.Info))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -898,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
|
||||
// Count clients that have exceeded quota but haven't recovered yet
|
||||
expiredClients := 0
|
||||
for _, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
@@ -987,13 +1095,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||
if reg.Providers != nil {
|
||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||
return info
|
||||
return cloneModelInfo(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to global info (last registered)
|
||||
return reg.Info
|
||||
return cloneModelInfo(reg.Info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1033,7 +1141,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
result["supported_parameters"] = model.SupportedParameters
|
||||
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if len(model.SupportedEndpoints) > 0 {
|
||||
result["supported_endpoints"] = model.SupportedEndpoints
|
||||
@@ -1094,13 +1202,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
||||
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedInputModalities) > 0 {
|
||||
result["supportedInputModalities"] = model.SupportedInputModalities
|
||||
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
|
||||
}
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
||||
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -1129,16 +1237,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
invalidated := false
|
||||
|
||||
for modelID, registration := range r.models {
|
||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
invalidated = true
|
||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
if invalidated {
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||
@@ -1152,8 +1264,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
// - string: The model ID of the first available model, or empty string if none available
|
||||
// - error: An error if no models are available
|
||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Get all available models for this handler type
|
||||
models := r.GetAvailableModels(handlerType)
|
||||
@@ -1213,13 +1323,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
result = append(result, cloneModelInfo(info))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
result = append(result, cloneModelInfo(reg.Info))
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
54
internal/registry/model_registry_cache_test.go
Normal file
54
internal/registry/model_registry_cache_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package registry
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||
|
||||
first := r.GetAvailableModels("openai")
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(first))
|
||||
}
|
||||
first[0]["id"] = "mutated"
|
||||
first[0]["display_name"] = "Mutated"
|
||||
|
||||
second := r.GetAvailableModels("openai")
|
||||
if got := second[0]["id"]; got != "m1" {
|
||||
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
|
||||
}
|
||||
if got := second[0]["display_name"]; got != "Model One" {
|
||||
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||
|
||||
models := r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(models))
|
||||
}
|
||||
if got := models[0]["display_name"]; got != "Model One" {
|
||||
t.Fatalf("expected initial display_name Model One, got %v", got)
|
||||
}
|
||||
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
|
||||
models = r.GetAvailableModels("openai")
|
||||
if got := models[0]["display_name"]; got != "Model One Updated" {
|
||||
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
|
||||
}
|
||||
|
||||
r.SuspendClientModel("client-1", "m1", "manual")
|
||||
models = r.GetAvailableModels("openai")
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("expected no available models after suspension, got %d", len(models))
|
||||
}
|
||||
|
||||
r.ResumeClientModel("client-1", "m1")
|
||||
models = r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected model to reappear after resume, got %d", len(models))
|
||||
}
|
||||
}
|
||||
149
internal/registry/model_registry_safety_test.go
Normal file
149
internal/registry/model_registry_safety_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetModelInfoReturnsClone(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetModelInfo("m1", "gemini")
|
||||
if first == nil {
|
||||
t.Fatal("expected model info")
|
||||
}
|
||||
first.DisplayName = "mutated"
|
||||
first.Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetModelInfo("m1", "gemini")
|
||||
if second.DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
|
||||
}
|
||||
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelsForClientReturnsClones(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetModelsForClient("client-1")
|
||||
if len(first) != 1 || first[0] == nil {
|
||||
t.Fatalf("expected one model, got %+v", first)
|
||||
}
|
||||
first[0].DisplayName = "mutated"
|
||||
first[0].Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetModelsForClient("client-1")
|
||||
if len(second) != 1 || second[0] == nil {
|
||||
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||
}
|
||||
if second[0].DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||
}
|
||||
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetAvailableModelsByProvider("gemini")
|
||||
if len(first) != 1 || first[0] == nil {
|
||||
t.Fatalf("expected one model, got %+v", first)
|
||||
}
|
||||
first[0].DisplayName = "mutated"
|
||||
first[0].Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetAvailableModelsByProvider("gemini")
|
||||
if len(second) != 1 || second[0] == nil {
|
||||
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||
}
|
||||
if second[0].DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||
}
|
||||
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
|
||||
r.SetModelQuotaExceeded("client-1", "m1")
|
||||
if models := r.GetAvailableModels("openai"); len(models) != 1 {
|
||||
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
|
||||
}
|
||||
|
||||
r.mutex.Lock()
|
||||
quotaTime := time.Now().Add(-6 * time.Minute)
|
||||
r.models["m1"].QuotaExceededClients["client-1"] = "aTime
|
||||
r.mutex.Unlock()
|
||||
|
||||
r.CleanupExpiredQuotas()
|
||||
|
||||
if count := r.GetModelCount("m1"); count != 1 {
|
||||
t.Fatalf("expected model count 1 after cleanup, got %d", count)
|
||||
}
|
||||
models := r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
|
||||
}
|
||||
if got := models[0]["id"]; got != "m1" {
|
||||
t.Fatalf("expected model id m1, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "openai", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
SupportedParameters: []string{"temperature", "top_p"},
|
||||
}})
|
||||
|
||||
first := r.GetAvailableModels("openai")
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("expected one model, got %d", len(first))
|
||||
}
|
||||
params, ok := first[0]["supported_parameters"].([]string)
|
||||
if !ok || len(params) != 2 {
|
||||
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
|
||||
}
|
||||
params[0] = "mutated"
|
||||
|
||||
second := r.GetAvailableModels("openai")
|
||||
params, ok = second[0]["supported_parameters"].([]string)
|
||||
if !ok || len(params) != 2 || params[0] != "temperature" {
|
||||
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
|
||||
first := LookupModelInfo("glm-4.6")
|
||||
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
|
||||
t.Fatalf("expected static model with thinking levels, got %+v", first)
|
||||
}
|
||||
first.Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := LookupModelInfo("glm-4.6")
|
||||
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
|
||||
t.Fatalf("expected static lookup clone, got %+v", second)
|
||||
}
|
||||
}
|
||||
372
internal/registry/model_updater.go
Normal file
372
internal/registry/model_updater.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
modelsFetchTimeout = 30 * time.Second
|
||||
modelsRefreshInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
var modelsURLs = []string{
|
||||
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
|
||||
"https://models.router-for.me/models.json",
|
||||
}
|
||||
|
||||
//go:embed models/models.json
|
||||
var embeddedModelsJSON []byte
|
||||
|
||||
type modelStore struct {
|
||||
mu sync.RWMutex
|
||||
data *staticModelsJSON
|
||||
}
|
||||
|
||||
var modelsCatalogStore = &modelStore{}
|
||||
|
||||
var updaterOnce sync.Once
|
||||
|
||||
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
|
||||
// changedProviders contains the provider names whose model definitions changed.
|
||||
type ModelRefreshCallback func(changedProviders []string)
|
||||
|
||||
var (
|
||||
refreshCallbackMu sync.Mutex
|
||||
refreshCallback ModelRefreshCallback
|
||||
pendingRefreshChanges []string
|
||||
)
|
||||
|
||||
// SetModelRefreshCallback registers a callback that is invoked when startup or
|
||||
// periodic model refresh detects changes. Only one callback is supported;
|
||||
// subsequent calls replace the previous callback.
|
||||
func SetModelRefreshCallback(cb ModelRefreshCallback) {
|
||||
refreshCallbackMu.Lock()
|
||||
refreshCallback = cb
|
||||
var pending []string
|
||||
if cb != nil && len(pendingRefreshChanges) > 0 {
|
||||
pending = append([]string(nil), pendingRefreshChanges...)
|
||||
pendingRefreshChanges = nil
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
|
||||
if cb != nil && len(pending) > 0 {
|
||||
cb(pending)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Load embedded data as fallback on startup.
|
||||
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// StartModelsUpdater starts a background updater that fetches models
|
||||
// immediately on startup and then refreshes the model catalog every 3 hours.
|
||||
// Safe to call multiple times; only one updater will run.
|
||||
func StartModelsUpdater(ctx context.Context) {
|
||||
updaterOnce.Do(func() {
|
||||
go runModelsUpdater(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func runModelsUpdater(ctx context.Context) {
|
||||
tryStartupRefresh(ctx)
|
||||
periodicRefresh(ctx)
|
||||
}
|
||||
|
||||
func periodicRefresh(ctx context.Context) {
|
||||
ticker := time.NewTicker(modelsRefreshInterval)
|
||||
defer ticker.Stop()
|
||||
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
tryPeriodicRefresh(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryPeriodicRefresh fetches models from remote, compares with the current
|
||||
// catalog, and notifies the registered callback if any provider changed.
|
||||
func tryPeriodicRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "periodic model refresh")
|
||||
}
|
||||
|
||||
// tryStartupRefresh fetches models from remote in the background during
|
||||
// process startup. It uses the same change detection as periodic refresh so
|
||||
// existing auth registrations can be updated after the callback is registered.
|
||||
func tryStartupRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "startup model refresh")
|
||||
}
|
||||
|
||||
func tryRefreshModels(ctx context.Context, label string) {
|
||||
oldData := getModels()
|
||||
|
||||
parsed, url := fetchModelsFromRemote(ctx)
|
||||
if parsed == nil {
|
||||
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
|
||||
return
|
||||
}
|
||||
|
||||
// Detect changes before updating store.
|
||||
changed := detectChangedProviders(oldData, parsed)
|
||||
|
||||
// Update store with new data regardless.
|
||||
modelsCatalogStore.mu.Lock()
|
||||
modelsCatalogStore.data = parsed
|
||||
modelsCatalogStore.mu.Unlock()
|
||||
|
||||
if len(changed) == 0 {
|
||||
log.Infof("%s completed from %s, no changes detected", label, url)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
|
||||
notifyModelRefresh(changed)
|
||||
}
|
||||
|
||||
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
|
||||
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
|
||||
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
|
||||
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||
for _, url := range modelsURLs {
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
log.Debugf("models fetch request creation failed for %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
log.Debugf("models fetch failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
resp.Body.Close()
|
||||
cancel()
|
||||
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("models fetch read error from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var parsed staticModelsJSON
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
log.Warnf("models parse failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
if err := validateModelsCatalog(&parsed); err != nil {
|
||||
log.Warnf("models validate failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
return &parsed, url
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// detectChangedProviders compares two model catalogs and returns provider names
|
||||
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
|
||||
// under a single "codex" provider.
|
||||
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||
if oldData == nil || newData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type section struct {
|
||||
provider string
|
||||
oldList []*ModelInfo
|
||||
newList []*ModelInfo
|
||||
}
|
||||
|
||||
sections := []section{
|
||||
{"claude", oldData.Claude, newData.Claude},
|
||||
{"gemini", oldData.Gemini, newData.Gemini},
|
||||
{"vertex", oldData.Vertex, newData.Vertex},
|
||||
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
|
||||
{"aistudio", oldData.AIStudio, newData.AIStudio},
|
||||
{"codex", oldData.CodexFree, newData.CodexFree},
|
||||
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
||||
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||
{"qwen", oldData.Qwen, newData.Qwen},
|
||||
{"iflow", oldData.IFlow, newData.IFlow},
|
||||
{"kimi", oldData.Kimi, newData.Kimi},
|
||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(sections))
|
||||
var changed []string
|
||||
for _, s := range sections {
|
||||
if seen[s.provider] {
|
||||
continue
|
||||
}
|
||||
if modelSectionChanged(s.oldList, s.newList) {
|
||||
changed = append(changed, s.provider)
|
||||
seen[s.provider] = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// modelSectionChanged reports whether two model slices differ.
|
||||
func modelSectionChanged(a, b []*ModelInfo) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return false
|
||||
}
|
||||
aj, err1 := json.Marshal(a)
|
||||
bj, err2 := json.Marshal(b)
|
||||
if err1 != nil || err2 != nil {
|
||||
return true
|
||||
}
|
||||
return string(aj) != string(bj)
|
||||
}
|
||||
|
||||
func notifyModelRefresh(changedProviders []string) {
|
||||
if len(changedProviders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
refreshCallbackMu.Lock()
|
||||
cb := refreshCallback
|
||||
if cb == nil {
|
||||
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
|
||||
refreshCallbackMu.Unlock()
|
||||
return
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
cb(changedProviders)
|
||||
}
|
||||
|
||||
func mergeProviderNames(existing, incoming []string) []string {
|
||||
if len(incoming) == 0 {
|
||||
return existing
|
||||
}
|
||||
seen := make(map[string]struct{}, len(existing)+len(incoming))
|
||||
merged := make([]string, 0, len(existing)+len(incoming))
|
||||
for _, provider := range existing {
|
||||
name := strings.ToLower(strings.TrimSpace(provider))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
for _, provider := range incoming {
|
||||
name := strings.ToLower(strings.TrimSpace(provider))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func loadModelsFromBytes(data []byte, source string) error {
|
||||
var parsed staticModelsJSON
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
return fmt.Errorf("%s: decode models catalog: %w", source, err)
|
||||
}
|
||||
if err := validateModelsCatalog(&parsed); err != nil {
|
||||
return fmt.Errorf("%s: validate models catalog: %w", source, err)
|
||||
}
|
||||
|
||||
modelsCatalogStore.mu.Lock()
|
||||
modelsCatalogStore.data = &parsed
|
||||
modelsCatalogStore.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func getModels() *staticModelsJSON {
|
||||
modelsCatalogStore.mu.RLock()
|
||||
defer modelsCatalogStore.mu.RUnlock()
|
||||
return modelsCatalogStore.data
|
||||
}
|
||||
|
||||
func validateModelsCatalog(data *staticModelsJSON) error {
|
||||
if data == nil {
|
||||
return fmt.Errorf("catalog is nil")
|
||||
}
|
||||
|
||||
requiredSections := []struct {
|
||||
name string
|
||||
models []*ModelInfo
|
||||
}{
|
||||
{name: "claude", models: data.Claude},
|
||||
{name: "gemini", models: data.Gemini},
|
||||
{name: "vertex", models: data.Vertex},
|
||||
{name: "gemini-cli", models: data.GeminiCLI},
|
||||
{name: "aistudio", models: data.AIStudio},
|
||||
{name: "codex-free", models: data.CodexFree},
|
||||
{name: "codex-team", models: data.CodexTeam},
|
||||
{name: "codex-plus", models: data.CodexPlus},
|
||||
{name: "codex-pro", models: data.CodexPro},
|
||||
{name: "qwen", models: data.Qwen},
|
||||
{name: "iflow", models: data.IFlow},
|
||||
{name: "kimi", models: data.Kimi},
|
||||
{name: "antigravity", models: data.Antigravity},
|
||||
}
|
||||
|
||||
for _, section := range requiredSections {
|
||||
if err := validateModelSection(section.name, section.models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateModelSection(section string, models []*ModelInfo) error {
|
||||
if len(models) == 0 {
|
||||
return fmt.Errorf("%s section is empty", section)
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for i, model := range models {
|
||||
if model == nil {
|
||||
return fmt.Errorf("%s[%d] is null", section, i)
|
||||
}
|
||||
modelID := strings.TrimSpace(model.ID)
|
||||
if modelID == "" {
|
||||
return fmt.Errorf("%s[%d] has empty id", section, i)
|
||||
}
|
||||
if _, exists := seen[modelID]; exists {
|
||||
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
|
||||
}
|
||||
seen[modelID] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
2683
internal/registry/models/models.json
Normal file
2683
internal/registry/models/models.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -43,7 +42,6 @@ const (
|
||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||
@@ -55,78 +53,8 @@ const (
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||
antigravityPrimaryModelsCache struct {
|
||||
mu sync.RWMutex
|
||||
models []*registry.ModelInfo
|
||||
}
|
||||
)
|
||||
|
||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*registry.ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneAntigravityModelInfo(model))
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *model
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
thinkingClone := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
clone.Thinking = &thinkingClone
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||
cloned := cloneAntigravityModels(models)
|
||||
if len(cloned) == 0 {
|
||||
return false
|
||||
}
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = cloned
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
antigravityPrimaryModelsCache.mu.RLock()
|
||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||
return cloned
|
||||
}
|
||||
|
||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
models := loadAntigravityPrimaryModels()
|
||||
if len(models) > 0 {
|
||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -1150,168 +1078,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
|
||||
var payload []byte
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||
}
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
payload = []byte(`{}`)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if host := resolveHost(baseURL); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
modelConfig := registry.GetAntigravityModelConfig()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for originalName, modelData := range result.Map() {
|
||||
modelID := strings.TrimSpace(originalName)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
|
||||
// Extract displayName from upstream response, fallback to modelID
|
||||
displayName := modelData.Get("displayName").String()
|
||||
if displayName == "" {
|
||||
displayName = modelID
|
||||
}
|
||||
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: modelID,
|
||||
Name: modelID,
|
||||
Description: displayName,
|
||||
DisplayName: displayName,
|
||||
Version: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
|
||||
// Build input modalities from upstream capability flags.
|
||||
inputModalities := []string{"TEXT"}
|
||||
if modelData.Get("supportsImages").Bool() {
|
||||
inputModalities = append(inputModalities, "IMAGE")
|
||||
}
|
||||
if modelData.Get("supportsVideo").Bool() {
|
||||
inputModalities = append(inputModalities, "VIDEO")
|
||||
}
|
||||
modelInfo.SupportedInputModalities = inputModalities
|
||||
modelInfo.SupportedOutputModalities = []string{"TEXT"}
|
||||
|
||||
// Token limits from upstream.
|
||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||
modelInfo.InputTokenLimit = int(maxTok)
|
||||
}
|
||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||
modelInfo.OutputTokenLimit = int(maxOut)
|
||||
}
|
||||
|
||||
// Supported generation methods (Gemini v1beta convention).
|
||||
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
|
||||
|
||||
// Look up Thinking support from static config using upstream model name.
|
||||
if modelCfg != nil {
|
||||
if modelCfg.Thinking != nil {
|
||||
modelInfo.Thinking = modelCfg.Thinking
|
||||
}
|
||||
if modelCfg.MaxCompletionTokens > 0 {
|
||||
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
|
||||
}
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
storeAntigravityPrimaryModels(models)
|
||||
return models
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = nil
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
seed := []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||
t.Fatal("expected non-empty model list to update primary cache")
|
||||
}
|
||||
|
||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||
ID: "gpt-5",
|
||||
DisplayName: "GPT-5",
|
||||
SupportedGenerationMethods: []string{"generateContent"},
|
||||
SupportedParameters: []string{"temperature"},
|
||||
Thinking: ®istry.ThinkingSupport{
|
||||
Levels: []string{"high"},
|
||||
},
|
||||
}}); !updated {
|
||||
t.Fatal("expected model cache update")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one cached model, got %d", len(got))
|
||||
}
|
||||
got[0].ID = "mutated-id"
|
||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||
}
|
||||
if len(got[0].SupportedParameters) > 0 {
|
||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||
}
|
||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||
got[0].Thinking.Levels[0] = "mutated-level"
|
||||
}
|
||||
|
||||
again := loadAntigravityPrimaryModels()
|
||||
if len(again) != 1 {
|
||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||
}
|
||||
if again[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||
}
|
||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||
}
|
||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||
}
|
||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||
}
|
||||
}
|
||||
@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if system.Type == gjson.String && system.String() != "" {
|
||||
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", system.String())
|
||||
result += "," + partJSON
|
||||
}
|
||||
result += "]"
|
||||
|
||||
@@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int {
|
||||
return count
|
||||
}
|
||||
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
|
||||
ccRaw, exists := obj["cache_control"]
|
||||
if !exists {
|
||||
return
|
||||
return false
|
||||
}
|
||||
cc, ok := asObject(ccRaw)
|
||||
if !ok {
|
||||
*seen5m = true
|
||||
return
|
||||
return false
|
||||
}
|
||||
ttlRaw, ttlExists := cc["ttl"]
|
||||
ttl, ttlIsString := ttlRaw.(string)
|
||||
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||
*seen5m = true
|
||||
return
|
||||
return false
|
||||
}
|
||||
if *seen5m {
|
||||
delete(cc, "ttl")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func findLastCacheControlIndex(arr []any) int {
|
||||
@@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
}
|
||||
|
||||
seen5m := false
|
||||
modified := false
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, tool := range tools {
|
||||
if obj, ok := asObject(tool); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return payload
|
||||
}
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
|
||||
@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
|
||||
// Payload where no TTL normalization is needed (all blocks use 1h with no
|
||||
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
|
||||
// that json.Marshal would escape to \u003c etc., altering byte identity.
|
||||
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if !bytes.Equal(out, payload) {
|
||||
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
@@ -829,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 1: String system prompt is preserved and converted to a content block
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
system := gjson.GetBytes(out, "system")
|
||||
if !system.IsArray() {
|
||||
t.Fatalf("system should be an array, got %s", system.Type)
|
||||
}
|
||||
|
||||
blocks := system.Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
|
||||
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
|
||||
}
|
||||
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
|
||||
}
|
||||
if blocks[2].Get("text").String() != "You are a helpful assistant." {
|
||||
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
|
||||
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 2: Strict mode drops the string system prompt
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, true)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 3: Empty string system prompt does not produce a spurious block
|
||||
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
|
||||
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 4: Array system prompt is unaffected by the string handling
|
||||
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
|
||||
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[2].Get("text").String() != "Be concise." {
|
||||
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 5: Special characters in string system prompt survive conversion
|
||||
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
||||
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
|
||||
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true)
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, false)
|
||||
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true)
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) {
|
||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
@@ -647,7 +647,8 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -31,7 +32,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
||||
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
||||
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
||||
)
|
||||
@@ -57,11 +58,6 @@ type codexWebsocketSession struct {
|
||||
wsURL string
|
||||
authID string
|
||||
|
||||
// connCreateSent tracks whether a `response.create` message has been successfully sent
|
||||
// on the current websocket connection. The upstream expects the first message on each
|
||||
// connection to be `response.create`.
|
||||
connCreateSent bool
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
activeMu sync.Mutex
|
||||
@@ -195,7 +191,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -212,13 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
defer sess.reqMu.Unlock()
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -280,10 +270,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
// execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry == nil && connRetry != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -312,7 +299,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
@@ -400,29 +386,23 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
|
||||
executionSessionID := executionSessionIDFromOptions(opts)
|
||||
var sess *codexWebsocketSession
|
||||
if executionSessionID != "" {
|
||||
sess = e.getOrCreateSession(executionSessionID)
|
||||
sess.reqMu.Lock()
|
||||
if sess != nil {
|
||||
sess.reqMu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -483,10 +463,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
sess.reqMu.Unlock()
|
||||
return nil, errDialRetry
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -515,7 +492,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
@@ -657,31 +633,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
|
||||
return conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
|
||||
func buildCodexWebsocketRequestBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
|
||||
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
|
||||
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
|
||||
//
|
||||
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
|
||||
// so we only use `response.append` after we have initialized the current connection.
|
||||
if allowAppend {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
|
||||
inputNode := gjson.GetBytes(body, "input")
|
||||
wsReqBody := []byte(`{}`)
|
||||
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
|
||||
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
|
||||
return wsReqBody
|
||||
}
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
|
||||
return wsReqBody
|
||||
}
|
||||
}
|
||||
|
||||
// Match codex-rs websocket v2 semantics: every request is `response.create`.
|
||||
// Incremental follow-up turns continue on the same websocket using
|
||||
// `previous_response_id` + incremental `input`, not `response.append`.
|
||||
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
||||
if errSet == nil && len(wsReqBody) > 0 {
|
||||
return wsReqBody
|
||||
@@ -725,21 +684,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
|
||||
}
|
||||
}
|
||||
|
||||
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
|
||||
if sess == nil || conn == nil || len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return
|
||||
}
|
||||
|
||||
sess.connMu.Lock()
|
||||
if sess.conn == conn {
|
||||
sess.connCreateSent = true
|
||||
}
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
|
||||
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -762,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
||||
return dialer
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(proxyURL)
|
||||
setting, errParse := proxyutil.Parse(proxyURL)
|
||||
if errParse != nil {
|
||||
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
|
||||
log.Errorf("codex websockets executor: %v", errParse)
|
||||
return dialer
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
switch setting.Mode {
|
||||
case proxyutil.ModeDirect:
|
||||
dialer.Proxy = nil
|
||||
return dialer
|
||||
case proxyutil.ModeProxy:
|
||||
default:
|
||||
return dialer
|
||||
}
|
||||
|
||||
switch setting.URL.Scheme {
|
||||
case "socks5":
|
||||
var proxyAuth *proxy.Auth
|
||||
if parsedURL.User != nil {
|
||||
username := parsedURL.User.Username()
|
||||
password, _ := parsedURL.User.Password()
|
||||
if setting.URL.User != nil {
|
||||
username := setting.URL.User.Username()
|
||||
password, _ := setting.URL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
||||
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return dialer
|
||||
@@ -786,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
||||
return socksDialer.Dial(network, addr)
|
||||
}
|
||||
case "http", "https":
|
||||
dialer.Proxy = http.ProxyURL(parsedURL)
|
||||
dialer.Proxy = http.ProxyURL(setting.URL)
|
||||
default:
|
||||
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
|
||||
}
|
||||
|
||||
return dialer
|
||||
@@ -844,7 +797,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
return rawJSON, headers
|
||||
}
|
||||
|
||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header {
|
||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
|
||||
if headers == nil {
|
||||
headers = http.Header{}
|
||||
}
|
||||
@@ -857,7 +810,8 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "")
|
||||
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
||||
@@ -872,7 +826,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent)
|
||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
isAPIKey := false
|
||||
if auth != nil && auth.Attributes != nil {
|
||||
@@ -900,6 +854,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
return headers
|
||||
}
|
||||
|
||||
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
|
||||
if cfg == nil || auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||
}
|
||||
|
||||
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||
if target == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(target.Get(key)) != "" {
|
||||
return
|
||||
}
|
||||
if source != nil {
|
||||
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
}
|
||||
if val := strings.TrimSpace(configValue); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||
target.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||
if target == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(target.Get(key)) != "" {
|
||||
return
|
||||
}
|
||||
if val := strings.TrimSpace(configValue); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
if source != nil {
|
||||
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
}
|
||||
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||
target.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
type statusErrWithHeaders struct {
|
||||
statusErr
|
||||
headers http.Header
|
||||
@@ -1017,36 +1027,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
|
||||
}
|
||||
}
|
||||
|
||||
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.SetReadDeadline(time.Now())
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
||||
if len(opts.Metadata) == 0 {
|
||||
return ""
|
||||
@@ -1120,7 +1100,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
|
||||
sess.conn = conn
|
||||
sess.wsURL = wsURL
|
||||
sess.authID = authID
|
||||
sess.connCreateSent = false
|
||||
sess.readerConn = conn
|
||||
sess.connMu.Unlock()
|
||||
|
||||
@@ -1206,7 +1185,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
|
||||
return
|
||||
}
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
@@ -1273,7 +1251,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
|
||||
authID := sess.authID
|
||||
wsURL := sess.wsURL
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
|
||||
203
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
203
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %s, want response.create", got)
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
|
||||
t.Fatalf("previous_response_id = %s, want resp-1", got)
|
||||
}
|
||||
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
|
||||
t.Fatalf("input item id mismatch")
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
|
||||
t.Fatalf("unexpected websocket request type: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "my-codex-client/1.0",
|
||||
BetaFeatures: "feature-a,feature-b",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
||||
}
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"User-Agent": "client-ua",
|
||||
"X-Codex-Beta-Features": "client-beta",
|
||||
})
|
||||
headers := http.Header{}
|
||||
headers.Set("User-Agent", "existing-ua")
|
||||
headers.Set("X-Codex-Beta-Features", "existing-beta")
|
||||
|
||||
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
||||
|
||||
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
|
||||
}
|
||||
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"User-Agent": "client-ua",
|
||||
"X-Codex-Beta-Features": "client-beta",
|
||||
})
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "config-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Attributes: map[string]string{"api_key": "sk-test"},
|
||||
}
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
req = req.WithContext(contextWithGinHeaders(map[string]string{
|
||||
"User-Agent": "client-ua",
|
||||
}))
|
||||
|
||||
applyCodexHeaders(req, auth, "oauth-token", true, cfg)
|
||||
|
||||
if got := req.Header.Get("User-Agent"); got != "config-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||
}
|
||||
if got := req.Header.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func contextWithGinHeaders(headers map[string]string) context.Context {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
ginCtx.Request.Header = make(http.Header, len(headers))
|
||||
for key, value := range headers {
|
||||
ginCtx.Request.Header.Set(key, value)
|
||||
}
|
||||
return context.WithValue(context.Background(), "gin", ginCtx)
|
||||
}
|
||||
|
||||
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dialer := newProxyAwareWebsocketDialer(
|
||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||
)
|
||||
|
||||
if dialer.Proxy != nil {
|
||||
t.Fatal("expected websocket proxy function to be nil for direct mode")
|
||||
}
|
||||
}
|
||||
@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
action := getVertexAction(baseModel, true)
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
// Imagen models don't support streaming, skip SSE params
|
||||
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
|
||||
|
||||
|
||||
@@ -522,9 +522,9 @@ func detectLastConversationRole(body []byte) string {
|
||||
}
|
||||
|
||||
switch item.Get("type").String() {
|
||||
case "function_call", "function_call_arguments":
|
||||
case "function_call", "function_call_arguments", "computer_call":
|
||||
return "assistant"
|
||||
case "function_call_output", "function_call_response", "tool_result":
|
||||
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||
return "tool"
|
||||
}
|
||||
}
|
||||
@@ -653,6 +653,7 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
body = stripGitHubCopilotResponsesUnsupportedFields(body)
|
||||
input := gjson.GetBytes(body, "input")
|
||||
if input.Exists() {
|
||||
// If input is already a string or array, keep it as-is.
|
||||
@@ -825,6 +826,12 @@ func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
|
||||
// GitHub Copilot /responses rejects service_tier, so always remove it.
|
||||
body, _ = sjson.DeleteBytes(body, "service_tier")
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
@@ -832,6 +839,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
if isGitHubCopilotResponsesBuiltinTool(toolType) {
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||
continue
|
||||
}
|
||||
// Accept OpenAI format (type="function") and Claude format
|
||||
// (no type field, but has top-level name + input_schema).
|
||||
if toolType != "" && toolType != "function" {
|
||||
@@ -879,6 +890,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
}
|
||||
if toolChoice.Type == gjson.JSON {
|
||||
choiceType := toolChoice.Get("type").String()
|
||||
if isGitHubCopilotResponsesBuiltinTool(choiceType) {
|
||||
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(toolChoice.Raw))
|
||||
return body
|
||||
}
|
||||
if choiceType == "function" {
|
||||
name := toolChoice.Get("name").String()
|
||||
if name == "" {
|
||||
@@ -896,6 +911,15 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func isGitHubCopilotResponsesBuiltinTool(toolType string) bool {
|
||||
switch strings.TrimSpace(toolType) {
|
||||
case "computer", "computer_use_preview":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func collectTextFromNode(node gjson.Result) string {
|
||||
if !node.Exists() {
|
||||
return ""
|
||||
|
||||
@@ -132,6 +132,19 @@ func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"user text","service_tier":"default"}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
|
||||
if gjson.GetBytes(got, "service_tier").Exists() {
|
||||
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "input").String() != "user text" {
|
||||
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
|
||||
1320
internal/runtime/executor/gitlab_executor.go
Normal file
1320
internal/runtime/executor/gitlab_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
469
internal/runtime/executor/gitlab_executor_test.go
Normal file
469
internal/runtime/executor/gitlab_executor_test.go
Normal file
@@ -0,0 +1,469 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitLabExecutorExecuteUsesChatEndpoint(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != gitLabChatEndpoint {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
_, _ = w.Write([]byte(`"chat response"`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "chat response" {
|
||||
t.Fatalf("expected chat response, got %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected resolved model, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteFallsBackToCodeSuggestions(t *testing.T) {
|
||||
chatCalls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case gitLabChatEndpoint:
|
||||
chatCalls++
|
||||
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||
case gitLabCodeSuggestionsEndpoint:
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"text": "fallback response",
|
||||
}},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"personal_access_token": "glpat-token",
|
||||
"auth_method": "pat",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"write code"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if chatCalls != 1 {
|
||||
t.Fatalf("expected chat endpoint to be tried once, got %d", chatCalls)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "fallback response" {
|
||||
t.Fatalf("expected fallback response, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesAnthropicGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{
|
||||
"model":"gitlab-duo",
|
||||
"messages":[{"role":"user","content":[{"type":"text","text":"list files"}]}],
|
||||
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}],
|
||||
"max_tokens":128
|
||||
}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotModel != "claude-sonnet-4-5" {
|
||||
t.Fatalf("model = %q, want claude-sonnet-4-5", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "content.0.type").String(); got != "tool_use" {
|
||||
t.Fatalf("expected tool_use response, got %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "content.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("expected tool name Bash, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from openai gateway\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from openai gateway\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-5-codex",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotModel != "gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from openai gateway" {
|
||||
t.Fatalf("expected openai gateway response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "oauth-refreshed",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"token_type": "Bearer",
|
||||
"scope": "api read_user",
|
||||
"created_at": 1710000000,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"oauth_client_secret": "client-secret",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := exec.Refresh(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("Refresh() error = %v", err)
|
||||
}
|
||||
if got := updated.Metadata["access_token"]; got != "oauth-refreshed" {
|
||||
t.Fatalf("expected refreshed access token, got %#v", got)
|
||||
}
|
||||
if got := updated.Metadata["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected refreshed model metadata, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesCodeSuggestionsSSE(t *testing.T) {
|
||||
var gotAccept, gotStreamingHeader, gotEncoding string
|
||||
var gotStreamFlag bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != gitLabCodeSuggestionsEndpoint {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
gotStreamingHeader = r.Header.Get(gitLabSSEStreamingHeader)
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotStreamFlag = gjson.GetBytes(readBody(t, r), "stream").Bool()
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: stream_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"model\":{\"name\":\"claude-sonnet-4-5\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||
_, _ = w.Write([]byte("data: {\"content\":\"hello\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||
_, _ = w.Write([]byte("data: {\"content\":\" world\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: stream_end\n"))
|
||||
_, _ = w.Write([]byte("data: {}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if gotAccept != "text/event-stream" {
|
||||
t.Fatalf("Accept = %q, want text/event-stream", gotAccept)
|
||||
}
|
||||
if gotStreamingHeader != "true" {
|
||||
t.Fatalf("%s = %q, want true", gitLabSSEStreamingHeader, gotStreamingHeader)
|
||||
}
|
||||
if gotEncoding != "identity" {
|
||||
t.Fatalf("Accept-Encoding = %q, want identity", gotEncoding)
|
||||
}
|
||||
if !gotStreamFlag {
|
||||
t.Fatalf("expected upstream request to set stream=true")
|
||||
}
|
||||
if len(lines) < 4 {
|
||||
t.Fatalf("expected translated stream chunks, got %d", len(lines))
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":"hello"`) {
|
||||
t.Fatalf("expected hello delta in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":" world"`) {
|
||||
t.Fatalf("expected world delta in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
last := lines[len(lines)-1]
|
||||
if last != "data: [DONE]" && !strings.Contains(last, `"finish_reason":"stop"`) {
|
||||
t.Fatalf("expected stream terminator, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||
chatCalls := 0
|
||||
streamCalls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case gitLabCodeSuggestionsEndpoint:
|
||||
streamCalls++
|
||||
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||
case gitLabChatEndpoint:
|
||||
chatCalls++
|
||||
_, _ = w.Write([]byte(`"chat fallback response"`))
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if streamCalls != 1 {
|
||||
t.Fatalf("expected streaming endpoint once, got %d", streamCalls)
|
||||
}
|
||||
if chatCalls != 1 {
|
||||
t.Fatalf("expected chat fallback once, got %d", chatCalls)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":"chat fallback response"`) {
|
||||
t.Fatalf("expected fallback content in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
var gotPath string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from gateway\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_stop\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"max_tokens":64}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func collectStreamLines(t *testing.T, result *cliproxyexecutor.StreamResult) []string {
|
||||
t.Helper()
|
||||
lines := make([]string, 0, 8)
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
lines = append(lines, string(chunk.Payload))
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func readBody(t *testing.T, r *http.Request) []byte {
|
||||
t.Helper()
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
return body
|
||||
}
|
||||
@@ -2,17 +2,15 @@ package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||
@@ -111,45 +109,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
// Returns:
|
||||
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
||||
func buildProxyTransport(proxyURL string) *http.Transport {
|
||||
if proxyURL == "" {
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(proxyURL)
|
||||
if errParse != nil {
|
||||
log.Errorf("parse proxy URL failed: %v", errParse)
|
||||
return nil
|
||||
}
|
||||
|
||||
var transport *http.Transport
|
||||
|
||||
// Handle different proxy schemes
|
||||
if parsedURL.Scheme == "socks5" {
|
||||
// Configure SOCKS5 proxy with optional authentication
|
||||
var proxyAuth *proxy.Auth
|
||||
if parsedURL.User != nil {
|
||||
username := parsedURL.User.Username()
|
||||
password, _ := parsedURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil
|
||||
}
|
||||
// Set up a custom transport using the SOCKS5 dialer
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
|
||||
// Configure HTTP or HTTPS proxy
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
|
||||
} else {
|
||||
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
return nil
|
||||
}
|
||||
|
||||
return transport
|
||||
}
|
||||
|
||||
30
internal/runtime/executor/proxy_helpers_test.go
Normal file
30
internal/runtime/executor/proxy_helpers_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := newProxyAwareHTTPClient(
|
||||
context.Background(),
|
||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||
0,
|
||||
)
|
||||
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("expected direct transport to disable proxy function")
|
||||
}
|
||||
}
|
||||
@@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
|
||||
if suffixResult.HasSuffix {
|
||||
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
|
||||
} else {
|
||||
config = extractThinkingConfig(body, toFormat)
|
||||
config = extractThinkingConfig(body, fromFormat)
|
||||
if !hasThinkingConfig(config) && fromFormat != toFormat {
|
||||
config = extractThinkingConfig(body, toFormat)
|
||||
}
|
||||
}
|
||||
|
||||
if !hasThinkingConfig(config) {
|
||||
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
|
||||
if config.Mode != ModeLevel {
|
||||
return config
|
||||
}
|
||||
if toFormat == "claude" {
|
||||
return config
|
||||
}
|
||||
if !isBudgetCapableProvider(toFormat) {
|
||||
return config
|
||||
}
|
||||
|
||||
55
internal/thinking/apply_user_defined_test.go
Normal file
55
internal/thinking/apply_user_defined_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package thinking_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "test-user-defined-claude-" + t.Name()
|
||||
modelID := "custom-claude-4-6"
|
||||
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(clientID)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "claude adaptive effort body",
|
||||
model: modelID,
|
||||
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
|
||||
},
|
||||
{
|
||||
name: "suffix level",
|
||||
model: modelID + "(high)",
|
||||
body: []byte(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyThinking() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
|
||||
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
if effort == "max" {
|
||||
effort = "high"
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
|
||||
@@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
|
||||
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
effort string
|
||||
expected string
|
||||
}{
|
||||
{"low", "low", "low"},
|
||||
{"medium", "medium", "medium"},
|
||||
{"high", "high", "high"},
|
||||
{"max", "max", "high"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"thinking": {"type": "adaptive"},
|
||||
"output_config": {"effort": "` + tt.effort + `"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||
if !thinkingConfig.Exists() {
|
||||
t.Fatal("thinkingConfig should exist for adaptive thinking")
|
||||
}
|
||||
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
|
||||
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
|
||||
}
|
||||
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||
t.Error("includeThoughts should be true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"thinking": {"type": "adaptive"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||
if !thinkingConfig.Exists() {
|
||||
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
|
||||
}
|
||||
if thinkingConfig.Get("thinkingLevel").String() != "high" {
|
||||
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
|
||||
}
|
||||
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||
t.Error("includeThoughts should be true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
|
||||
@@ -138,20 +138,31 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
CallNames []string // ordered function call names for backfilling empty response names
|
||||
}
|
||||
|
||||
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
|
||||
// Falls back to a minimal "functionResponse" object when parsing fails.
|
||||
func parseFunctionResponseRaw(response gjson.Result) string {
|
||||
// fallbackName is used when the response's own name is empty.
|
||||
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
|
||||
if response.IsObject() && gjson.Valid(response.Raw) {
|
||||
return response.Raw
|
||||
raw := response.Raw
|
||||
name := response.Get("functionResponse.name").String()
|
||||
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
||||
raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName)
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
log.Debugf("parse function response failed, using fallback")
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
|
||||
name := funcResp.Get("name").String()
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = fallbackName
|
||||
}
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", name)
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr, _ = sjson.Set(fr, "functionResponse.id", id)
|
||||
@@ -159,7 +170,12 @@ func parseFunctionResponseRaw(response gjson.Result) string {
|
||||
return fr
|
||||
}
|
||||
|
||||
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
|
||||
useName := fallbackName
|
||||
if useName == "" {
|
||||
useName = "unknown"
|
||||
}
|
||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", useName)
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
|
||||
return fr
|
||||
}
|
||||
@@ -211,30 +227,26 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
if len(responsePartsInThisContent) > 0 {
|
||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||
|
||||
// Check if any pending groups can be satisfied
|
||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||
group := pendingGroups[i]
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
// Check if pending groups can be satisfied (FIFO: oldest group first)
|
||||
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
|
||||
group := pendingGroups[0]
|
||||
pendingGroups = pendingGroups[1:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response)
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
}
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for ri, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||
break
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,15 +255,15 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
functionCallsCount := 0
|
||||
var callNames []string
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsCount++
|
||||
callNames = append(callNames, part.Get("functionCall.name").String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if functionCallsCount > 0 {
|
||||
if len(callNames) > 0 {
|
||||
// Add the model content
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse model content")
|
||||
@@ -261,7 +273,8 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ResponsesNeeded: functionCallsCount,
|
||||
ResponsesNeeded: len(callNames),
|
||||
CallNames: callNames,
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
@@ -291,8 +304,8 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response)
|
||||
for ri, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
}
|
||||
|
||||
@@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
|
||||
// When the Amp client sends functionResponse with an empty name,
|
||||
// fixCLIToolResponse should backfill it from the corresponding functionCall.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
name := funcContent.Get("parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
|
||||
// Parallel function calls: both responses have empty names.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
|
||||
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
|
||||
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
parts := funcContent.Get("parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 function response parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
name0 := parts[0].Get("functionResponse.name").String()
|
||||
name1 := parts[1].Get("functionResponse.name").String()
|
||||
if name0 != "Read" {
|
||||
t.Errorf("Expected first response name 'Read', got '%s'", name0)
|
||||
}
|
||||
if name1 != "Grep" {
|
||||
t.Errorf("Expected second response name 'Grep', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
|
||||
// When functionResponse already has a valid name, it should be preserved.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
name := funcContent.Get("parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
|
||||
// If there are more function responses than calls, unmatched extras are discarded by grouping.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
|
||||
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
// First response should be backfilled from the call
|
||||
name0 := funcContent.Get("parts.0.functionResponse.name").String()
|
||||
if name0 != "Bash" {
|
||||
t.Errorf("Expected first response name 'Bash', got '%s'", name0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
|
||||
// Two sequential function call groups should be matched FIFO.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Read", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "file content"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Grep", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "match"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContents []gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContents = append(funcContents, c)
|
||||
}
|
||||
}
|
||||
if len(funcContents) != 2 {
|
||||
t.Fatalf("Expected 2 function contents, got %d", len(funcContents))
|
||||
}
|
||||
|
||||
name0 := funcContents[0].Get("parts.0.functionResponse.name").String()
|
||||
name1 := funcContents[1].Get("parts.0.functionResponse.name").String()
|
||||
if name0 != "Read" {
|
||||
t.Errorf("Expected first group name 'Read', got '%s'", name0)
|
||||
}
|
||||
if name1 != "Grep" {
|
||||
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else {
|
||||
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
|
||||
}
|
||||
case "input_audio":
|
||||
audioData := item.Get("input_audio.data").String()
|
||||
audioFormat := item.Get("input_audio.format").String()
|
||||
if audioData != "" {
|
||||
audioMimeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"webm": "audio/webm",
|
||||
"pcm16": "audio/pcm",
|
||||
"g711_ulaw": "audio/basic",
|
||||
"g711_alaw": "audio/basic",
|
||||
}
|
||||
mimeType := "audio/wav"
|
||||
if audioFormat != "" {
|
||||
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||
mimeType = mapped
|
||||
} else {
|
||||
mimeType = "audio/" + audioFormat
|
||||
}
|
||||
}
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
|
||||
p++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
|
||||
|
||||
case "image_url":
|
||||
// Convert OpenAI image format to Claude Code format
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Extract base64 data and media type from data URL
|
||||
parts := strings.Split(imageURL, ",")
|
||||
if len(parts) == 2 {
|
||||
mediaTypePart := strings.Split(parts[0], ";")[0]
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
data := parts[1]
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -291,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
case "tool":
|
||||
// Handle tool result messages conversion
|
||||
toolCallID := message.Get("tool_call_id").String()
|
||||
content := message.Get("content").String()
|
||||
toolContentResult := message.Get("content")
|
||||
|
||||
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
||||
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
||||
msg, _ = sjson.Set(msg, "content.0.content", content)
|
||||
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
|
||||
if toolResultContentRaw {
|
||||
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
|
||||
} else {
|
||||
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
messageIndex++
|
||||
}
|
||||
@@ -358,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
|
||||
switch part.Get("type").String() {
|
||||
case "text":
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
return textPart
|
||||
|
||||
case "image_url":
|
||||
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
return docPart
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertOpenAIImageURLToClaudePart(imageURL string) string {
|
||||
if imageURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
parts := strings.SplitN(imageURL, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
if mediaType == "" {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
|
||||
return imagePart
|
||||
}
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
|
||||
return imagePart
|
||||
}
|
||||
|
||||
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||
if !content.Exists() {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String(), false
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
claudeContent := "[]"
|
||||
partCount := 0
|
||||
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Type == gjson.String {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.String())
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
|
||||
partCount++
|
||||
return true
|
||||
}
|
||||
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||
partCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if partCount > 0 || len(content.Array()) == 0 {
|
||||
return claudeContent, true
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
claudePart := convertOpenAIContentPartToClaudePart(content)
|
||||
if claudePart != "" {
|
||||
claudeContent := "[]"
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||
return claudeContent, true
|
||||
}
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_work",
|
||||
"arguments": "{\"a\":1}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": [
|
||||
{"type": "text", "text": "tool ok"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolResult := messages[1].Get("content.0")
|
||||
if got := toolResult.Get("type").String(); got != "tool_result" {
|
||||
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
|
||||
}
|
||||
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
|
||||
}
|
||||
|
||||
toolContent := toolResult.Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
|
||||
}
|
||||
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
|
||||
}
|
||||
if got := toolContent.Get("1.type").String(); got != "image" {
|
||||
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.type").String(); got != "base64" {
|
||||
t.Fatalf("Expected image source type %q, got %q", "base64", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
|
||||
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
|
||||
t.Fatalf("Unexpected base64 image data: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_work",
|
||||
"arguments": "{\"a\":1}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/tool.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content.0.content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "image" {
|
||||
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
|
||||
}
|
||||
if got := toolContent.Get("0.source.type").String(); got != "url" {
|
||||
t.Fatalf("Expected image source type %q, got %q", "url", got)
|
||||
}
|
||||
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
|
||||
t.Fatalf("Unexpected image URL: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Process system messages and convert them to input content format.
|
||||
systemsResult := rootResult.Get("system")
|
||||
if systemsResult.IsArray() {
|
||||
systemResults := systemsResult.Array()
|
||||
if systemsResult.Exists() {
|
||||
message := `{"type":"message","role":"developer","content":[]}`
|
||||
contentIndex := 0
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemResult := systemResults[i]
|
||||
systemTypeResult := systemResult.Get("type")
|
||||
if systemTypeResult.String() == "text" {
|
||||
text := systemResult.Get("text").String()
|
||||
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
continue
|
||||
|
||||
appendSystemText := func(text string) {
|
||||
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
return
|
||||
}
|
||||
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
|
||||
if systemsResult.Type == gjson.String {
|
||||
appendSystemText(systemsResult.String())
|
||||
} else if systemsResult.IsArray() {
|
||||
systemResults := systemsResult.Array()
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemResult := systemResults[i]
|
||||
if systemResult.Get("type").String() == "text" {
|
||||
appendSystemText(systemResult.Get("text").String())
|
||||
}
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if contentIndex > 0 {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantHasDeveloper bool
|
||||
wantTexts []string
|
||||
}{
|
||||
{
|
||||
name: "No system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": "",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: false,
|
||||
},
|
||||
{
|
||||
name: "String system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": "Be helpful",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: true,
|
||||
wantTexts: []string{"Be helpful"},
|
||||
},
|
||||
{
|
||||
name: "Array system field with filtered billing header",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": [
|
||||
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
|
||||
{"type": "text", "text": "Block 1"},
|
||||
{"type": "text", "text": "Block 2"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: true,
|
||||
wantTexts: []string{"Block 1", "Block 2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
inputs := resultJSON.Get("input").Array()
|
||||
|
||||
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
|
||||
if hasDeveloper != tt.wantHasDeveloper {
|
||||
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
|
||||
}
|
||||
|
||||
if !tt.wantHasDeveloper {
|
||||
return
|
||||
}
|
||||
|
||||
content := inputs[0].Get("content").Array()
|
||||
if len(content) != len(tt.wantTexts) {
|
||||
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
|
||||
}
|
||||
|
||||
for i, wantText := range tt.wantTexts {
|
||||
if gotType := content[i].Get("type").String(); gotType != "input_text" {
|
||||
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
|
||||
}
|
||||
if gotText := content[i].Get("text").String(); gotText != wantText {
|
||||
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||
|
||||
@@ -25,7 +25,12 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
// rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
if v := gjson.GetBytes(rawJSON, "service_tier"); v.Exists() {
|
||||
if v.String() != "priority" {
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
}
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -116,6 +117,17 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
CallNames []string // ordered function call names for backfilling empty response names
|
||||
}
|
||||
|
||||
// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name,
|
||||
// falling back to fallbackName if the original is empty.
|
||||
func backfillFunctionResponseName(raw string, fallbackName string) string {
|
||||
name := gjson.Get(raw, "functionResponse.name").String()
|
||||
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
||||
raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName)
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
@@ -165,31 +177,28 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
if len(responsePartsInThisContent) > 0 {
|
||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||
|
||||
// Check if any pending groups can be satisfied
|
||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||
group := pendingGroups[i]
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
// Check if pending groups can be satisfied (FIFO: oldest group first)
|
||||
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
|
||||
group := pendingGroups[0]
|
||||
pendingGroups = pendingGroups[1:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for ri, response := range groupResponses {
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw)
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||
break
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,15 +207,15 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
functionCallsCount := 0
|
||||
var callNames []string
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsCount++
|
||||
callNames = append(callNames, part.Get("functionCall.name").String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if functionCallsCount > 0 {
|
||||
if len(callNames) > 0 {
|
||||
// Add the model content
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse model content")
|
||||
@@ -216,7 +225,8 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ResponsesNeeded: functionCallsCount,
|
||||
ResponsesNeeded: len(callNames),
|
||||
CallNames: callNames,
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
@@ -246,12 +256,13 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
for ri, response := range groupResponses {
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
|
||||
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw)
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
|
||||
@@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", clientToolName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
@@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
||||
toolIDCounter++
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
|
||||
inputRaw := "{}"
|
||||
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
|
||||
@@ -5,9 +5,11 @@ package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -95,6 +97,71 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
|
||||
out = []byte(strJson)
|
||||
}
|
||||
|
||||
// Backfill empty functionResponse.name from the preceding functionCall.name.
|
||||
// Amp may send function responses with empty names; the Gemini API rejects these.
|
||||
out = backfillEmptyFunctionResponseNames(out)
|
||||
|
||||
out = common.AttachDefaultSafetySettings(out, "safetySettings")
|
||||
return out
|
||||
}
|
||||
|
||||
// backfillEmptyFunctionResponseNames walks the contents array and for each
|
||||
// model turn containing functionCall parts, records the call names in order.
|
||||
// For the immediately following user/function turn containing functionResponse
|
||||
// parts, any empty name is replaced with the corresponding call name.
|
||||
func backfillEmptyFunctionResponseNames(data []byte) []byte {
|
||||
contents := gjson.GetBytes(data, "contents")
|
||||
if !contents.Exists() {
|
||||
return data
|
||||
}
|
||||
|
||||
out := data
|
||||
var pendingCallNames []string
|
||||
|
||||
contents.ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
role := content.Get("role").String()
|
||||
|
||||
// Collect functionCall names from model turns
|
||||
if role == "model" {
|
||||
var names []string
|
||||
content.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
names = append(names, part.Get("functionCall.name").String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
if len(names) > 0 {
|
||||
pendingCallNames = names
|
||||
} else {
|
||||
pendingCallNames = nil
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Backfill empty functionResponse names from pending call names
|
||||
if len(pendingCallNames) > 0 {
|
||||
ri := 0
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
if part.Get("functionResponse").Exists() {
|
||||
name := part.Get("functionResponse.name").String()
|
||||
if strings.TrimSpace(name) == "" {
|
||||
if ri < len(pendingCallNames) {
|
||||
out, _ = sjson.SetBytes(out,
|
||||
fmt.Sprintf("contents.%d.parts.%d.functionResponse.name", contentIdx.Int(), partIdx.Int()),
|
||||
pendingCallNames[ri])
|
||||
} else {
|
||||
log.Debugf("more function responses than calls at contents[%d], skipping name backfill", contentIdx.Int())
|
||||
}
|
||||
}
|
||||
ri++
|
||||
}
|
||||
return true
|
||||
})
|
||||
pendingCallNames = nil
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
193
internal/translator/gemini/gemini/gemini_gemini_request_test.go
Normal file
193
internal/translator/gemini/gemini/gemini_gemini_request_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBackfillEmptyFunctionResponseNames_Single(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := backfillEmptyFunctionResponseNames(input)
|
||||
|
||||
name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillEmptyFunctionResponseNames_Parallel(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
|
||||
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
|
||||
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := backfillEmptyFunctionResponseNames(input)
|
||||
|
||||
name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
|
||||
name1 := gjson.GetBytes(out, "contents.1.parts.1.functionResponse.name").String()
|
||||
if name0 != "Read" {
|
||||
t.Errorf("Expected first name 'Read', got '%s'", name0)
|
||||
}
|
||||
if name1 != "Grep" {
|
||||
t.Errorf("Expected second name 'Grep', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillEmptyFunctionResponseNames_PreservesExisting(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := backfillEmptyFunctionResponseNames(input)
|
||||
|
||||
name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToGemini_BackfillsEmptyName(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertGeminiRequestToGemini("", input, false)
|
||||
|
||||
name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillEmptyFunctionResponseNames_MoreResponsesThanCalls(t *testing.T) {
|
||||
// Extra responses beyond the call count should not panic and should be left unchanged.
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
|
||||
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := backfillEmptyFunctionResponseNames(input)
|
||||
|
||||
name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
|
||||
if name0 != "Bash" {
|
||||
t.Errorf("Expected first name 'Bash', got '%s'", name0)
|
||||
}
|
||||
// Second response has no matching call, should remain empty
|
||||
name1 := gjson.GetBytes(out, "contents.1.parts.1.functionResponse.name").String()
|
||||
if name1 != "" {
|
||||
t.Errorf("Expected second name to remain empty, got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillEmptyFunctionResponseNames_MultipleGroups(t *testing.T) {
|
||||
// Two sequential call/response groups should each get correct names.
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Read", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "content"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Grep", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "match"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := backfillEmptyFunctionResponseNames(input)
|
||||
|
||||
name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String()
|
||||
name1 := gjson.GetBytes(out, "contents.3.parts.0.functionResponse.name").String()
|
||||
if name0 != "Read" {
|
||||
t.Errorf("Expected first group name 'Read', got '%s'", name0)
|
||||
}
|
||||
if name1 != "Grep" {
|
||||
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
content := m.Get("content")
|
||||
|
||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||
// system -> system_instruction as a user message style
|
||||
// system -> systemInstruction as a user message style
|
||||
if content.Type == gjson.String {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
|
||||
systemPartIndex++
|
||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||
systemPartIndex++
|
||||
} else if content.IsArray() {
|
||||
contents := content.Array()
|
||||
if len(contents) > 0 {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
for j := 0; j < len(contents); j++ {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||
systemPartIndex++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
if instructions := root.Get("instructions"); instructions.Exists() {
|
||||
systemInstr := `{"parts":[{"text":""}]}`
|
||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||
}
|
||||
|
||||
// Convert input messages to Gemini contents format
|
||||
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
if strings.EqualFold(itemRole, "system") {
|
||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||
systemInstr := ""
|
||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
||||
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
|
||||
systemInstr = systemInstructionResult.Raw
|
||||
} else {
|
||||
systemInstr = `{"parts":[]}`
|
||||
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
if systemInstr != `{"parts":[]}` {
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||
}
|
||||
}
|
||||
continue
|
||||
@@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
|
||||
}
|
||||
}
|
||||
case "input_audio":
|
||||
audioData := contentItem.Get("data").String()
|
||||
audioFormat := contentItem.Get("format").String()
|
||||
if audioData != "" {
|
||||
audioMimeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"webm": "audio/webm",
|
||||
"pcm16": "audio/pcm",
|
||||
"g711_ulaw": "audio/basic",
|
||||
"g711_alaw": "audio/basic",
|
||||
}
|
||||
mimeType := "audio/wav"
|
||||
if audioFormat != "" {
|
||||
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||
mimeType = mapped
|
||||
} else {
|
||||
mimeType = "audio/" + audioFormat
|
||||
}
|
||||
}
|
||||
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
|
||||
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
|
||||
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
|
||||
}
|
||||
}
|
||||
|
||||
if partJSON != "" {
|
||||
|
||||
@@ -183,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
|
||||
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
|
||||
toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
|
||||
if toolResultContentRaw {
|
||||
toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
|
||||
} else {
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
|
||||
}
|
||||
toolResults = append(toolResults, toolResultJSON)
|
||||
}
|
||||
return true
|
||||
@@ -374,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertClaudeToolResultContentToString(content gjson.Result) string {
|
||||
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
|
||||
if !content.Exists() {
|
||||
return ""
|
||||
return "", false
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
return content.String(), false
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
var parts []string
|
||||
contentJSON := "[]"
|
||||
hasImagePart := false
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
switch {
|
||||
case item.Type == gjson.String:
|
||||
parts = append(parts, item.String())
|
||||
text := item.String()
|
||||
parts = append(parts, text)
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", text)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
|
||||
case item.IsObject() && item.Get("type").String() == "text":
|
||||
text := item.Get("text").String()
|
||||
parts = append(parts, text)
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", text)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
|
||||
case item.IsObject() && item.Get("type").String() == "image":
|
||||
contentItem, ok := convertClaudeContentPart(item)
|
||||
if ok {
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
|
||||
hasImagePart = true
|
||||
} else {
|
||||
parts = append(parts, item.Raw)
|
||||
}
|
||||
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
|
||||
parts = append(parts, item.Get("text").String())
|
||||
default:
|
||||
@@ -397,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string {
|
||||
return true
|
||||
})
|
||||
|
||||
if hasImagePart {
|
||||
return contentJSON, true
|
||||
}
|
||||
|
||||
joined := strings.Join(parts, "\n\n")
|
||||
if strings.TrimSpace(joined) != "" {
|
||||
return joined
|
||||
return joined, false
|
||||
}
|
||||
return content.Raw
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String()
|
||||
if content.Get("type").String() == "image" {
|
||||
contentItem, ok := convertClaudeContentPart(content)
|
||||
if ok {
|
||||
contentJSON := "[]"
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
|
||||
return contentJSON, true
|
||||
}
|
||||
}
|
||||
return content.Raw
|
||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String(), false
|
||||
}
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
return content.Raw
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
@@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_1",
|
||||
"content": [
|
||||
{"type": "text", "text": "tool ok"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
|
||||
}
|
||||
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
|
||||
}
|
||||
if got := toolContent.Get("1.type").String(); got != "image_url" {
|
||||
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
|
||||
}
|
||||
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
|
||||
t.Fatalf("Unexpected image_url: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_1",
|
||||
"content": {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": "https://example.com/tool.png"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "image_url" {
|
||||
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
|
||||
}
|
||||
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
|
||||
t.Fatalf("Unexpected image_url: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
|
||||
@@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Send content_block_start for tool_use
|
||||
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
|
||||
}
|
||||
@@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
||||
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
|
||||
|
||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||
@@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
toolCalls.ForEach(func(_, tc gjson.Result) bool {
|
||||
hasToolCall = true
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
|
||||
toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
|
||||
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
|
||||
|
||||
argsStr := util.FixJSON(tc.Get("function.arguments").String())
|
||||
@@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
hasToolCall = true
|
||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
|
||||
|
||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||
|
||||
24
internal/util/claude_tool_id.go
Normal file
24
internal/util/claude_tool_id.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
claudeToolUseIDCounter uint64
|
||||
)
|
||||
|
||||
// SanitizeClaudeToolID ensures the given id conforms to Claude's
|
||||
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
|
||||
// replaced with '_'; an empty result gets a generated fallback.
|
||||
func SanitizeClaudeToolID(id string) string {
|
||||
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
|
||||
if s == "" {
|
||||
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -4,50 +4,25 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
|
||||
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
|
||||
// to route requests through the configured proxy server.
|
||||
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
|
||||
var transport *http.Transport
|
||||
// Attempt to parse the proxy URL from the configuration.
|
||||
proxyURL, errParse := url.Parse(cfg.ProxyURL)
|
||||
if errParse == nil {
|
||||
// Handle different proxy schemes.
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
// Configure SOCKS5 proxy with optional authentication.
|
||||
var proxyAuth *proxy.Auth
|
||||
if proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return httpClient
|
||||
}
|
||||
// Set up a custom transport using the SOCKS5 dialer.
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
// Configure HTTP or HTTPS proxy.
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
if cfg == nil || httpClient == nil {
|
||||
return httpClient
|
||||
}
|
||||
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
}
|
||||
// If a new transport was created, apply it to the HTTP client.
|
||||
if transport != nil {
|
||||
httpClient.Transport = transport
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
@@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
ctx := &synthesizer.SynthesisContext{
|
||||
Config: cfg,
|
||||
AuthDir: resolvedAuthDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
|
||||
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
|
||||
w.fileAuthsByPath[normalizedPath] = pathAuths
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
if cfg == nil {
|
||||
if w.config == nil {
|
||||
log.Error("config is nil, cannot add or update client")
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if w.fileAuthsByPath == nil {
|
||||
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||
}
|
||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
||||
w.clientsMutex.Unlock()
|
||||
@@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after add/update")
|
||||
w.triggerServerUpdate(cfg)
|
||||
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||
for id, a := range w.fileAuthsByPath[normalized] {
|
||||
oldByID[id] = a
|
||||
}
|
||||
|
||||
// Build synthesized auth entries for this single file only.
|
||||
sctx := &synthesizer.SynthesisContext{
|
||||
Config: w.config,
|
||||
AuthDir: w.authDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
|
||||
newByID := authSliceToMap(generated)
|
||||
if len(newByID) > 0 {
|
||||
w.fileAuthsByPath[normalized] = newByID
|
||||
} else {
|
||||
delete(w.fileAuthsByPath, normalized)
|
||||
}
|
||||
updates := w.computePerPathUpdatesLocked(oldByID, newByID)
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) removeClient(path string) {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||
for id, a := range w.fileAuthsByPath[normalized] {
|
||||
oldByID[id] = a
|
||||
}
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
delete(w.lastAuthContents, normalized)
|
||||
delete(w.fileAuthsByPath, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{})
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after removal")
|
||||
w.triggerServerUpdate(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))
|
||||
for id, newAuth := range newByID {
|
||||
existing, ok := w.currentAuths[id]
|
||||
if !ok {
|
||||
w.currentAuths[id] = newAuth.Clone()
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()})
|
||||
continue
|
||||
}
|
||||
if !authEqual(existing, newAuth) {
|
||||
w.currentAuths[id] = newAuth.Clone()
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()})
|
||||
}
|
||||
}
|
||||
for id := range oldByID {
|
||||
if _, stillExists := newByID[id]; stillExists {
|
||||
continue
|
||||
}
|
||||
delete(w.currentAuths, id)
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
|
||||
}
|
||||
return updates
|
||||
}
|
||||
|
||||
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
|
||||
byID := make(map[string]*coreauth.Auth, len(auths))
|
||||
for _, a := range auths {
|
||||
if a == nil || strings.TrimSpace(a.ID) == "" {
|
||||
continue
|
||||
}
|
||||
byID[a.ID] = a
|
||||
}
|
||||
return byID
|
||||
}
|
||||
|
||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
var snapshotCoreAuthsFunc = snapshotCoreAuths
|
||||
|
||||
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
@@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState(force bool) {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.RLock()
|
||||
cfg := w.config
|
||||
authDir := w.authDir
|
||||
w.clientsMutex.RUnlock()
|
||||
auths := snapshotCoreAuthsFunc(cfg, authDir)
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
for _, a := range w.runtimeAuths {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
@@ -36,9 +37,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
return out, nil
|
||||
}
|
||||
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
@@ -52,99 +50,130 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
if errRead != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
auths := synthesizeFileAuths(ctx, full, data)
|
||||
if len(auths) == 0 {
|
||||
continue
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||
id := full
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
|
||||
if runtime.GOOS == "windows" {
|
||||
id = strings.ToLower(id)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
disabled, _ := metadata["disabled"].(bool)
|
||||
status := coreauth.StatusActive
|
||||
if disabled {
|
||||
status = coreauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Read per-account excluded models from the OAuth JSON file
|
||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: status,
|
||||
Disabled: disabled,
|
||||
Attributes: map[string]string{
|
||||
"source": full,
|
||||
"path": full,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
// Read priority from auth file
|
||||
if rawPriority, ok := metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||
case string:
|
||||
priority := strings.TrimSpace(v)
|
||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||
a.Attributes["priority"] = priority
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, auths...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SynthesizeAuthFile generates Auth entries for one auth JSON file payload.
|
||||
// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize.
|
||||
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
|
||||
return synthesizeFileAuths(ctx, fullPath, data)
|
||||
}
|
||||
|
||||
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
|
||||
if ctx == nil || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
return nil
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
return nil
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store.
|
||||
id := fullPath
|
||||
if strings.TrimSpace(ctx.AuthDir) != "" {
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
id = strings.ToLower(id)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
disabled, _ := metadata["disabled"].(bool)
|
||||
status := coreauth.StatusActive
|
||||
if disabled {
|
||||
status = coreauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Read per-account excluded models from the OAuth JSON file.
|
||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: status,
|
||||
Disabled: disabled,
|
||||
Attributes: map[string]string{
|
||||
"source": fullPath,
|
||||
"path": fullPath,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
// Read priority from auth file.
|
||||
if rawPriority, ok := metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||
case string:
|
||||
priority := strings.TrimSpace(v)
|
||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||
a.Attributes["priority"] = priority
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
// For codex auth files, extract plan_type from the JWT id_token.
|
||||
if provider == "codex" {
|
||||
if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
|
||||
if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
|
||||
if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
|
||||
a.Attributes["plan_type"] = pt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||
}
|
||||
out := make([]*coreauth.Auth, 0, 1+len(virtuals))
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
return out
|
||||
}
|
||||
}
|
||||
return []*coreauth.Auth{a}
|
||||
}
|
||||
|
||||
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
||||
// It disables the primary auth and creates one virtual auth per project.
|
||||
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||
|
||||
@@ -45,6 +45,7 @@ type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastAuthContents map[string]*coreauth.Auth
|
||||
fileAuthsByPath map[string]map[string]*coreauth.Auth
|
||||
lastRemoveTimes map[string]time.Time
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
@@ -92,11 +93,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
|
||||
return nil, errNewWatcher
|
||||
}
|
||||
w := &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: reloadCallback,
|
||||
watcher: watcher,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: reloadCallback,
|
||||
watcher: watcher,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
|
||||
}
|
||||
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
||||
if store := sdkAuth.GetTokenStore(); store != nil {
|
||||
|
||||
@@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no reload callback for auth update, got %d", got)
|
||||
}
|
||||
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
|
||||
normalized := w.normalizeAuthPath(authFile)
|
||||
@@ -436,8 +436,110 @@ func TestRemoveClientRemovesHash(t *testing.T) {
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected hash to be removed after deletion")
|
||||
}
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no reload callback for auth removal, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to create auth file: %v", err)
|
||||
}
|
||||
|
||||
origSnapshot := snapshotCoreAuthsFunc
|
||||
var snapshotCalls int32
|
||||
snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
|
||||
atomic.AddInt32(&snapshotCalls, 1)
|
||||
return origSnapshot(cfg, authDir)
|
||||
}
|
||||
defer func() { snapshotCoreAuthsFunc = origSnapshot }()
|
||||
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
lastAuthContents: make(map[string]*coreauth.Auth),
|
||||
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
w.removeClient(authFile)
|
||||
|
||||
if got := atomic.LoadInt32(&snapshotCalls); got != 0 {
|
||||
t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSliceToMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
valid1 := &coreauth.Auth{ID: "a"}
|
||||
valid2 := &coreauth.Auth{ID: "b"}
|
||||
dupOld := &coreauth.Auth{ID: "dup", Label: "old"}
|
||||
dupNew := &coreauth.Auth{ID: "dup", Label: "new"}
|
||||
empty := &coreauth.Auth{ID: " "}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in []*coreauth.Auth
|
||||
want map[string]*coreauth.Auth
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
in: nil,
|
||||
want: map[string]*coreauth.Auth{},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
in: []*coreauth.Auth{},
|
||||
want: map[string]*coreauth.Auth{},
|
||||
},
|
||||
{
|
||||
name: "filters invalid auths",
|
||||
in: []*coreauth.Auth{nil, empty},
|
||||
want: map[string]*coreauth.Auth{},
|
||||
},
|
||||
{
|
||||
name: "keeps valid auths",
|
||||
in: []*coreauth.Auth{valid1, nil, valid2},
|
||||
want: map[string]*coreauth.Auth{"a": valid1, "b": valid2},
|
||||
},
|
||||
{
|
||||
name: "last duplicate wins",
|
||||
in: []*coreauth.Auth{dupOld, dupNew},
|
||||
want: map[string]*coreauth.Auth{"dup": dupNew},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := authSliceToMap(tc.in)
|
||||
if len(tc.want) == 0 {
|
||||
if got == nil {
|
||||
t.Fatal("expected empty map, got nil")
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty map, got %#v", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(got) != len(tc.want) {
|
||||
t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
|
||||
}
|
||||
for id, wantAuth := range tc.want {
|
||||
gotAuth, ok := got[id]
|
||||
if !ok {
|
||||
t.Fatalf("missing id %q in result map", id)
|
||||
}
|
||||
if !authEqual(gotAuth, wantAuth) {
|
||||
t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -695,8 +797,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) {
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected no reload callback for auth removal, got %d", reloads)
|
||||
}
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected hash entry to be removed")
|
||||
@@ -893,8 +995,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
|
||||
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -990,8 +1092,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1045,8 +1147,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected known remove to avoid global reload, got %d", reloads)
|
||||
}
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected known auth hash to be deleted")
|
||||
|
||||
151
sdk/api/handlers/claude/gitlab_duo_handler_test.go
Normal file
151
sdk/api/handlers/claude/gitlab_duo_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestClaudeMessagesWithGitLabDuoAnthropicGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath, gotAuthHeader, gotRealmHeader string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager, _ := registerGitLabDuoAnthropicAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewClaudeCodeAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/messages", h.ClaudeMessages)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"max_tokens":128,
|
||||
"messages":[{"role":"user","content":"list files"}],
|
||||
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}]
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("x-gitlab-realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"tool_use"`) {
|
||||
t.Fatalf("expected tool_use response, got %s", resp.Body.String())
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"Bash"`) {
|
||||
t.Fatalf("expected Bash tool in response, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeMessagesStreamWithGitLabDuoAnthropicGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from duo\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_stop\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager, _ := registerGitLabDuoAnthropicAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewClaudeCodeAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/messages", h.ClaudeMessages)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"stream":true,
|
||||
"max_tokens":64,
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if got := resp.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Fatalf("content-type = %q, want text/event-stream", got)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), "event: content_block_delta") {
|
||||
t.Fatalf("expected streamed claude event, got %s", resp.Body.String())
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), "hello from duo") {
|
||||
t.Fatalf("expected streamed text, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func registerGitLabDuoAnthropicAuth(t *testing.T, upstreamURL string) (*coreauth.Manager, string) {
|
||||
t.Helper()
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(runtimeexecutor.NewGitLabExecutor(&internalconfig.Config{}))
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-duo-claude-handler-test",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": upstreamURL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
registered, err := manager.Register(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(registered.ID, registered.Provider, runtimeexecutor.GitLabModelsFromAuth(registered))
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(registered.ID)
|
||||
})
|
||||
return manager, registered.ID
|
||||
}
|
||||
143
sdk/api/handlers/openai/gitlab_duo_handler_test.go
Normal file
143
sdk/api/handlers/openai/gitlab_duo_handler_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestOpenAIChatCompletionsWithGitLabDuoOpenAIGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath, gotAuthHeader, gotRealmHeader string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from duo openai\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"status\":\"completed\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from duo openai\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager := registerGitLabDuoOpenAIAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/chat/completions", h.ChatCompletions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{
|
||||
"model":"gpt-5-codex",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("x-gitlab-realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"content":"hello from duo openai"`) {
|
||||
t.Fatalf("expected translated chat completion, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesStreamWithGitLabDuoOpenAIGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath, gotAuthHeader string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"streamed duo output\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"status\":\"completed\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"streamed duo output\"}]}],\"usage\":{\"input_tokens\":10,\"output_tokens\":3,\"total_tokens\":13}}}\n\n"))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager := registerGitLabDuoOpenAIAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/responses", h.Responses)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{
|
||||
"model":"gpt-5-codex",
|
||||
"stream":true,
|
||||
"input":"hello"
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if got := resp.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Fatalf("content-type = %q, want text/event-stream", got)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"type":"response.output_text.delta"`) {
|
||||
t.Fatalf("expected streamed responses delta, got %s", resp.Body.String())
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"type":"response.completed"`) {
|
||||
t.Fatalf("expected streamed responses completion, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func registerGitLabDuoOpenAIAuth(t *testing.T, upstreamURL string) *coreauth.Manager {
|
||||
t.Helper()
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(runtimeexecutor.NewGitLabExecutor(&internalconfig.Config{}))
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-duo-openai-handler-test",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": upstreamURL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-5-codex",
|
||||
},
|
||||
}
|
||||
registered, err := manager.Register(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(registered.ID, registered.Provider, runtimeexecutor.GitLabModelsFromAuth(registered))
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(registered.ID)
|
||||
})
|
||||
return manager
|
||||
}
|
||||
@@ -14,7 +14,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -26,11 +30,12 @@ const (
|
||||
wsRequestTypeAppend = "response.append"
|
||||
wsEventTypeError = "error"
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsEventTypeDone = "response.done"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsPayloadLogMaxSize = 2048
|
||||
wsBodyLogMaxSize = 64 * 1024
|
||||
wsBodyLogTruncated = "\n[websocket log truncated]\n"
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
@@ -101,11 +106,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||
allowIncrementalInputWithPreviousResponseID := false
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
} else {
|
||||
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
}
|
||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
@@ -140,6 +151,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||
requestJSON = updated
|
||||
}
|
||||
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
|
||||
updatedLastRequest = updated
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
lastResponseOutput = []byte("[]")
|
||||
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
||||
wsTerminateErr = errWrite
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
@@ -340,6 +367,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||
if h == nil || h.AuthManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
resolvedModelName := modelName
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||
if initialSuffix.HasSuffix {
|
||||
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||
} else {
|
||||
resolvedModelName = resolvedBase
|
||||
}
|
||||
} else {
|
||||
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||
}
|
||||
|
||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||
providers := util.GetProviderName(baseModel)
|
||||
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||
providers = util.GetProviderName(resolvedModelName)
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for i := 0; i < len(providers); i++ {
|
||||
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[providerKey] = struct{}{}
|
||||
}
|
||||
if len(providerSet) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
modelKey := baseModel
|
||||
if modelKey == "" {
|
||||
modelKey = strings.TrimSpace(resolvedModelName)
|
||||
}
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
now := time.Now()
|
||||
auths := h.AuthManager.List()
|
||||
for i := 0; i < len(auths); i++ {
|
||||
auth := auths[i]
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
||||
continue
|
||||
}
|
||||
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
||||
continue
|
||||
}
|
||||
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
|
||||
return false
|
||||
}
|
||||
if modelName != "" && len(auth.ModelStates) > 0 {
|
||||
state, ok := auth.ModelStates[modelName]
|
||||
if (!ok || state == nil) && modelName != "" {
|
||||
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
|
||||
if baseModel != "" && baseModel != modelName {
|
||||
state, ok = auth.ModelStates[baseModel]
|
||||
}
|
||||
}
|
||||
if ok && state != nil {
|
||||
if state.Status == coreauth.StatusDisabled {
|
||||
return false
|
||||
}
|
||||
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
|
||||
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||
return false
|
||||
}
|
||||
generateResult := gjson.GetBytes(rawJSON, "generate")
|
||||
return generateResult.Exists() && !generateResult.Bool()
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketSyntheticPrewarm(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
requestJSON []byte,
|
||||
wsBodyLog *strings.Builder,
|
||||
sessionID string,
|
||||
) error {
|
||||
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
||||
if errPayloads != nil {
|
||||
return errPayloads
|
||||
}
|
||||
for i := 0; i < len(payloads); i++ {
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
// websocket.TextMessage,
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(payloads[i]),
|
||||
errWrite,
|
||||
)
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
|
||||
responseID := "resp_prewarm_" + uuid.NewString()
|
||||
createdAt := time.Now().Unix()
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
|
||||
|
||||
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||
var errSet error
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
if modelName != "" {
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
if modelName != "" {
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
return [][]byte{createdPayload, completedPayload}, nil
|
||||
}
|
||||
|
||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||
existingRaw = strings.TrimSpace(existingRaw)
|
||||
appendRaw = strings.TrimSpace(appendRaw)
|
||||
@@ -469,9 +682,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
for i := range payloads {
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||
|
||||
completed = true
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
@@ -554,65 +764,134 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
|
||||
}
|
||||
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
payload := map[string]any{
|
||||
"type": wsEventTypeError,
|
||||
"status": status,
|
||||
payload := []byte(`{}`)
|
||||
var errSet error
|
||||
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
payload, errSet = sjson.SetBytes(payload, "status", status)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
|
||||
if errMsg != nil && errMsg.Addon != nil {
|
||||
headers := map[string]any{}
|
||||
headers := []byte(`{}`)
|
||||
hasHeaders := false
|
||||
for key, values := range errMsg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
headers[key] = values[0]
|
||||
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
|
||||
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
hasHeaders = true
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
payload["headers"] = headers
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
var decoded map[string]any
|
||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||
if inner, ok := decoded["error"]; ok {
|
||||
payload["error"] = inner
|
||||
} else {
|
||||
payload["error"] = decoded
|
||||
if hasHeaders {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; !ok {
|
||||
payload["error"] = map[string]any{
|
||||
"type": "server_error",
|
||||
"message": errText,
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
errorNode := gjson.GetBytes(body, "error")
|
||||
if errorNode.Exists() {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
|
||||
} else {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "error", body)
|
||||
}
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !gjson.GetBytes(payload, "error").Exists() {
|
||||
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
if builder.Len() >= wsBodyLogMaxSize {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
if !appendWebsocketLogString(builder, "\n") {
|
||||
return
|
||||
}
|
||||
}
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
if !appendWebsocketLogString(builder, "websocket.") {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogString(builder, eventType) {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogString(builder, "\n") {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
|
||||
appendWebsocketLogString(builder, wsBodyLogTruncated)
|
||||
return
|
||||
}
|
||||
appendWebsocketLogString(builder, "\n")
|
||||
}
|
||||
|
||||
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
|
||||
if builder == nil {
|
||||
return false
|
||||
}
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if len(value) <= remaining {
|
||||
builder.WriteString(value)
|
||||
return true
|
||||
}
|
||||
builder.WriteString(value[:remaining])
|
||||
return false
|
||||
}
|
||||
|
||||
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
|
||||
if builder == nil {
|
||||
return false
|
||||
}
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if len(value) <= remaining {
|
||||
builder.Write(value)
|
||||
return true
|
||||
}
|
||||
limit := remaining - reserveForSuffix
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
if limit > len(value) {
|
||||
limit = len(value)
|
||||
}
|
||||
builder.Write(value[:limit])
|
||||
return false
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
|
||||
@@ -2,15 +2,57 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type websocketCaptureExecutor struct {
|
||||
streamCalls int
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.streamCalls++
|
||||
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
@@ -224,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
|
||||
|
||||
appendWebsocketEvent(&builder, "request", payload)
|
||||
|
||||
got := builder.String()
|
||||
if len(got) > wsBodyLogMaxSize {
|
||||
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
|
||||
}
|
||||
if !strings.Contains(got, wsBodyLogTruncated) {
|
||||
t.Fatalf("expected truncation marker in body log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
|
||||
initial := builder.String()
|
||||
|
||||
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
|
||||
|
||||
if builder.String() != initial {
|
||||
t.Fatalf("builder grew after reaching limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -247,3 +316,206 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
serverErrCh <- errClose
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Request = r
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errCh := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
|
||||
close(data)
|
||||
close(errCh)
|
||||
|
||||
var bodyLog strings.Builder
|
||||
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
|
||||
ctx,
|
||||
conn,
|
||||
func(...interface{}) {},
|
||||
data,
|
||||
errCh,
|
||||
&bodyLog,
|
||||
"session-1",
|
||||
)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
|
||||
serverErrCh <- errors.New("completed output not captured")
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if strings.Contains(string(payload), "response.done") {
|
||||
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
|
||||
}
|
||||
|
||||
if errServer := <-serverErrCh; errServer != nil {
|
||||
t.Fatalf("server error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: "test-provider",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
|
||||
t.Fatalf("expected websocket-capable upstream for test-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
executor := &websocketCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
|
||||
if errWrite != nil {
|
||||
t.Fatalf("write prewarm websocket message: %v", errWrite)
|
||||
}
|
||||
|
||||
_, createdPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read prewarm created message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
|
||||
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
|
||||
}
|
||||
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
|
||||
if prewarmResponseID == "" {
|
||||
t.Fatalf("prewarm response id is empty")
|
||||
}
|
||||
if executor.streamCalls != 0 {
|
||||
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
|
||||
}
|
||||
|
||||
_, completedPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read prewarm completed message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
|
||||
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
|
||||
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
|
||||
}
|
||||
|
||||
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
|
||||
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
|
||||
if errWrite != nil {
|
||||
t.Fatalf("write follow-up websocket message: %v", errWrite)
|
||||
}
|
||||
|
||||
_, upstreamPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read upstream completed message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if executor.streamCalls != 1 {
|
||||
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
|
||||
}
|
||||
if len(executor.payloads) != 1 {
|
||||
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
|
||||
}
|
||||
forwarded := executor.payloads[0]
|
||||
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
|
||||
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
|
||||
}
|
||||
if gjson.GetBytes(forwarded, "generate").Exists() {
|
||||
t.Fatalf("generate leaked upstream: %s", forwarded)
|
||||
}
|
||||
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
|
||||
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
|
||||
}
|
||||
input := gjson.GetBytes(forwarded, "input").Array()
|
||||
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
Attributes: map[string]string{
|
||||
"plan_type": planType,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
485
sdk/auth/gitlab.go
Normal file
485
sdk/auth/gitlab.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
gitLabLoginModeMetadataKey = "login_mode"
|
||||
gitLabLoginModeOAuth = "oauth"
|
||||
gitLabLoginModePAT = "pat"
|
||||
gitLabBaseURLMetadataKey = "base_url"
|
||||
gitLabOAuthClientIDMetadataKey = "oauth_client_id"
|
||||
gitLabOAuthClientSecretMetadataKey = "oauth_client_secret"
|
||||
gitLabPersonalAccessTokenMetadataKey = "personal_access_token"
|
||||
)
|
||||
|
||||
var gitLabRefreshLead = 5 * time.Minute
|
||||
|
||||
type GitLabAuthenticator struct {
|
||||
CallbackPort int
|
||||
}
|
||||
|
||||
func NewGitLabAuthenticator() *GitLabAuthenticator {
|
||||
return &GitLabAuthenticator{CallbackPort: gitlabauth.DefaultCallbackPort}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) Provider() string {
|
||||
return "gitlab"
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) RefreshLead() *time.Duration {
|
||||
return &gitLabRefreshLead
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(opts.Metadata[gitLabLoginModeMetadataKey])) {
|
||||
case "", gitLabLoginModeOAuth:
|
||||
return a.loginOAuth(ctx, cfg, opts)
|
||||
case gitLabLoginModePAT:
|
||||
return a.loginPAT(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("gitlab auth: unsupported login mode %q", opts.Metadata[gitLabLoginModeMetadataKey])
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) loginOAuth(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
client := gitlabauth.NewAuthClient(cfg)
|
||||
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
|
||||
clientID, err := a.requireInput(opts, gitLabOAuthClientIDMetadataKey, "Enter GitLab OAuth application client ID: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSecret, err := a.optionalInput(opts, gitLabOAuthClientSecretMetadataKey, "Enter GitLab OAuth application client secret (press Enter for public PKCE app): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
redirectURI := gitlabauth.RedirectURL(callbackPort)
|
||||
|
||||
pkceCodes, err := gitlabauth.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := gitlabauth.NewOAuthServer(callbackPort)
|
||||
if err := oauthServer.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
|
||||
log.Warnf("gitlab oauth server stop error: %v", stopErr)
|
||||
}
|
||||
}()
|
||||
|
||||
authURL, err := client.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for GitLab Duo authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for GitLab OAuth callback...")
|
||||
|
||||
callbackCh := make(chan *gitlabauth.OAuthResult, 1)
|
||||
callbackErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
result, waitErr := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if waitErr != nil {
|
||||
callbackErrCh <- waitErr
|
||||
return
|
||||
}
|
||||
callbackCh <- result
|
||||
}()
|
||||
|
||||
var result *gitlabauth.OAuthResult
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
input, promptErr := opts.Prompt("Paste the GitLab callback URL (or press Enter to keep waiting): ")
|
||||
if promptErr != nil {
|
||||
return nil, promptErr
|
||||
}
|
||||
parsed, parseErr := misc.ParseOAuthCallback(input)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
result = &gitlabauth.OAuthResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
}
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("gitlab oauth returned error: %s", result.Error)
|
||||
}
|
||||
if result.State != state {
|
||||
return nil, fmt.Errorf("gitlab auth: state mismatch")
|
||||
}
|
||||
|
||||
tokenResp, err := client.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, result.Code, pkceCodes.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("gitlab auth: missing access token")
|
||||
}
|
||||
|
||||
user, err := client.GetCurrentUser(ctx, baseURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata[gitLabOAuthClientIDMetadataKey] = clientID
|
||||
if strings.TrimSpace(clientSecret) != "" {
|
||||
metadata[gitLabOAuthClientSecretMetadataKey] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
fmt.Println("GitLab Duo authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: identifier,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) loginPAT(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
client := gitlabauth.NewAuthClient(cfg)
|
||||
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
|
||||
token, err := a.requireInput(opts, gitLabPersonalAccessTokenMetadataKey, "Enter GitLab personal access token: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := client.GetCurrentUser(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = client.GetPersonalAccessTokenSelf(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
|
||||
metadata["auth_kind"] = "personal_access_token"
|
||||
metadata[gitLabPersonalAccessTokenMetadataKey] = strings.TrimSpace(token)
|
||||
metadata["token_preview"] = maskGitLabToken(token)
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
fmt.Println("GitLab Duo PAT authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: identifier + " (PAT)",
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
|
||||
metadata := map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": strings.TrimSpace(mode),
|
||||
gitLabBaseURLMetadataKey: gitlabauth.NormalizeBaseURL(baseURL),
|
||||
"last_refresh": time.Now().UTC().Format(time.RFC3339),
|
||||
"refresh_interval_seconds": 240,
|
||||
}
|
||||
if tokenResp != nil {
|
||||
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
|
||||
metadata["refresh_token"] = refreshToken
|
||||
}
|
||||
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
|
||||
metadata["token_type"] = tokenType
|
||||
}
|
||||
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
|
||||
metadata["scope"] = scope
|
||||
}
|
||||
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
|
||||
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
mergeGitLabDirectAccessMetadata(metadata, direct)
|
||||
return metadata
|
||||
}
|
||||
|
||||
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
|
||||
if metadata == nil || direct == nil {
|
||||
return
|
||||
}
|
||||
if base := strings.TrimSpace(direct.BaseURL); base != "" {
|
||||
metadata["duo_gateway_base_url"] = base
|
||||
}
|
||||
if token := strings.TrimSpace(direct.Token); token != "" {
|
||||
metadata["duo_gateway_token"] = token
|
||||
}
|
||||
if direct.ExpiresAt > 0 {
|
||||
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
|
||||
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
|
||||
now := time.Now().UTC()
|
||||
if ttl := expiry.Sub(now); ttl > 0 {
|
||||
interval := int(ttl.Seconds()) / 2
|
||||
switch {
|
||||
case interval < 60:
|
||||
interval = 60
|
||||
case interval > 240:
|
||||
interval = 240
|
||||
}
|
||||
metadata["refresh_interval_seconds"] = interval
|
||||
}
|
||||
}
|
||||
if len(direct.Headers) > 0 {
|
||||
headers := make(map[string]string, len(direct.Headers))
|
||||
for key, value := range direct.Headers {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
headers[key] = value
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
metadata["duo_gateway_headers"] = headers
|
||||
}
|
||||
}
|
||||
if direct.ModelDetails != nil {
|
||||
modelDetails := map[string]any{}
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
modelDetails["model_provider"] = provider
|
||||
metadata["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
modelDetails["model_name"] = model
|
||||
metadata["model_name"] = model
|
||||
}
|
||||
if len(modelDetails) > 0 {
|
||||
metadata["model_details"] = modelDetails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) resolveString(opts *LoginOptions, key, fallback string) string {
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
if value := strings.TrimSpace(opts.Metadata[key]); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
for _, envKey := range gitLabEnvKeys(key) {
|
||||
if raw, ok := os.LookupEnv(envKey); ok {
|
||||
if trimmed := strings.TrimSpace(raw); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(fallback) != "" {
|
||||
return fallback
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) requireInput(opts *LoginOptions, key, prompt string) (string, error) {
|
||||
if value := a.resolveString(opts, key, ""); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
value, err := opts.Prompt(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("gitlab auth: missing required %s", key)
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) optionalInput(opts *LoginOptions, key, prompt string) (string, error) {
|
||||
if value := a.resolveString(opts, key, ""); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
value, err := opts.Prompt(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func primaryGitLabEmail(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
if value := strings.TrimSpace(user.Email); value != "" {
|
||||
return value
|
||||
}
|
||||
return strings.TrimSpace(user.PublicEmail)
|
||||
}
|
||||
|
||||
func gitLabAccountIdentifier(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return "user"
|
||||
}
|
||||
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
|
||||
func sanitizeGitLabFileName(value string) string {
|
||||
value = strings.TrimSpace(strings.ToLower(value))
|
||||
if value == "" {
|
||||
return "user"
|
||||
}
|
||||
var builder strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r >= '0' && r <= '9':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
default:
|
||||
if !lastDash {
|
||||
builder.WriteRune('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
}
|
||||
result := strings.Trim(builder.String(), "-")
|
||||
if result == "" {
|
||||
return "user"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskGitLabToken(token string) string {
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if len(trimmed) <= 8 {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
|
||||
}
|
||||
|
||||
func gitLabEnvKeys(key string) []string {
|
||||
switch strings.TrimSpace(key) {
|
||||
case gitLabBaseURLMetadataKey:
|
||||
return []string{"GITLAB_BASE_URL"}
|
||||
case gitLabOAuthClientIDMetadataKey:
|
||||
return []string{"GITLAB_OAUTH_CLIENT_ID"}
|
||||
case gitLabOAuthClientSecretMetadataKey:
|
||||
return []string{"GITLAB_OAUTH_CLIENT_SECRET"}
|
||||
case gitLabPersonalAccessTokenMetadataKey:
|
||||
return []string{"GITLAB_PERSONAL_ACCESS_TOKEN"}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
66
sdk/auth/gitlab_test.go
Normal file
66
sdk/auth/gitlab_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestGitLabAuthenticatorLoginPAT(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v4/user":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 42,
|
||||
"username": "duo-user",
|
||||
"email": "duo@example.com",
|
||||
"name": "Duo User",
|
||||
})
|
||||
case "/api/v4/personal_access_tokens/self":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 5,
|
||||
"name": "CLIProxyAPI",
|
||||
"scopes": []string{"api"},
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
authenticator := NewGitLabAuthenticator()
|
||||
record, err := authenticator.Login(context.Background(), &config.Config{}, &LoginOptions{
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "pat",
|
||||
"base_url": srv.URL,
|
||||
"personal_access_token": "glpat-test-token",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Login() error = %v", err)
|
||||
}
|
||||
if record.Provider != "gitlab" {
|
||||
t.Fatalf("expected gitlab provider, got %q", record.Provider)
|
||||
}
|
||||
if got := record.Metadata["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected discovered model, got %#v", got)
|
||||
}
|
||||
if got := record.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||
t.Fatalf("expected personal_access_token auth kind, got %#v", got)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ func init() {
|
||||
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -134,6 +134,7 @@ type Manager struct {
|
||||
hook Hook
|
||||
mu sync.RWMutex
|
||||
auths map[string]*Auth
|
||||
scheduler *authScheduler
|
||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||
providerOffsets map[string]int
|
||||
|
||||
@@ -149,6 +150,9 @@ type Manager struct {
|
||||
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
|
||||
apiKeyModelAlias atomic.Value
|
||||
|
||||
// modelPoolOffsets tracks per-auth alias pool rotation state.
|
||||
modelPoolOffsets map[string]int
|
||||
|
||||
// runtimeConfig stores the latest application config for request-time decisions.
|
||||
// It is initialized in NewManager; never Load() before first Store().
|
||||
runtimeConfig atomic.Value
|
||||
@@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
hook: hook,
|
||||
auths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
modelPoolOffsets: make(map[string]int),
|
||||
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
||||
}
|
||||
// atomic.Value requires non-nil initial value.
|
||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
|
||||
manager.scheduler = newAuthScheduler(selector)
|
||||
return manager
|
||||
}
|
||||
|
||||
func isBuiltInSelector(selector Selector) bool {
|
||||
switch selector.(type) {
|
||||
case *RoundRobinSelector, *FillFirstSelector:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
|
||||
if m == nil || m.scheduler == nil {
|
||||
return
|
||||
}
|
||||
m.scheduler.rebuild(auths)
|
||||
}
|
||||
|
||||
func (m *Manager) syncScheduler() {
|
||||
if m == nil || m.scheduler == nil {
|
||||
return
|
||||
}
|
||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||
}
|
||||
|
||||
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||
// supportedModelSet is rebuilt from the current global model registry state.
|
||||
// This must be called after models have been registered for a newly added auth,
|
||||
// because the initial scheduler.upsertAuth during Register/Update runs before
|
||||
// registerModelsForAuth and therefore snapshots an empty model set.
|
||||
func (m *Manager) RefreshSchedulerEntry(authID string) {
|
||||
if m == nil || m.scheduler == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.RLock()
|
||||
auth, ok := m.auths[authID]
|
||||
if !ok || auth == nil {
|
||||
m.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
snapshot := auth.Clone()
|
||||
m.mu.RUnlock()
|
||||
m.scheduler.upsertAuth(snapshot)
|
||||
}
|
||||
|
||||
func (m *Manager) SetSelector(selector Selector) {
|
||||
if m == nil {
|
||||
return
|
||||
@@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) {
|
||||
m.mu.Lock()
|
||||
m.selector = selector
|
||||
m.mu.Unlock()
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.setSelector(selector)
|
||||
m.syncScheduler()
|
||||
}
|
||||
}
|
||||
|
||||
// SetStore swaps the underlying persistence store.
|
||||
@@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin
|
||||
if resolved == "" {
|
||||
return ""
|
||||
}
|
||||
// Preserve thinking suffix from the client's requested model unless config already has one.
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
||||
return resolved
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return resolved
|
||||
return preserveRequestedModelSuffix(requestedModel, resolved)
|
||||
}
|
||||
|
||||
func isAPIKeyAuth(auth *Auth) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
kind, _ := auth.AccountInfo()
|
||||
return strings.EqualFold(strings.TrimSpace(kind), "api_key")
|
||||
}
|
||||
|
||||
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
|
||||
if !isAPIKeyAuth(auth) {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||
return true
|
||||
}
|
||||
if auth.Attributes == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
|
||||
}
|
||||
|
||||
func openAICompatProviderKey(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" {
|
||||
return strings.ToLower(providerKey)
|
||||
}
|
||||
if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" {
|
||||
return strings.ToLower(compatName)
|
||||
}
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
}
|
||||
|
||||
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
|
||||
base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
|
||||
if base == "" {
|
||||
base = strings.TrimSpace(requestedModel)
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base)
|
||||
}
|
||||
|
||||
func (m *Manager) nextModelPoolOffset(key string, size int) int {
|
||||
if m == nil || size <= 1 {
|
||||
return 0
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return 0
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.modelPoolOffsets == nil {
|
||||
m.modelPoolOffsets = make(map[string]int)
|
||||
}
|
||||
offset := m.modelPoolOffsets[key]
|
||||
if offset >= 2_147_483_640 {
|
||||
offset = 0
|
||||
}
|
||||
m.modelPoolOffsets[key] = offset + 1
|
||||
if size <= 0 {
|
||||
return 0
|
||||
}
|
||||
return offset % size
|
||||
}
|
||||
|
||||
func rotateStrings(values []string, offset int) []string {
|
||||
if len(values) <= 1 {
|
||||
return values
|
||||
}
|
||||
if offset <= 0 {
|
||||
out := make([]string, len(values))
|
||||
copy(out, values)
|
||||
return out
|
||||
}
|
||||
offset = offset % len(values)
|
||||
out := make([]string, 0, len(values))
|
||||
out = append(out, values[offset:]...)
|
||||
out = append(out, values[:offset]...)
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
|
||||
if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
|
||||
return nil
|
||||
}
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
||||
if cfg == nil {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
providerKey := ""
|
||||
compatName := ""
|
||||
if auth.Attributes != nil {
|
||||
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
|
||||
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
|
||||
}
|
||||
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||
}
|
||||
|
||||
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
|
||||
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
|
||||
}
|
||||
|
||||
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
|
||||
return m.prepareExecutionModels(auth, routeModel)
|
||||
}
|
||||
|
||||
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
|
||||
requestedModel := rewriteModelForAuth(routeModel, auth)
|
||||
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
|
||||
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
|
||||
if len(pool) == 1 {
|
||||
return pool
|
||||
}
|
||||
offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool))
|
||||
return rotateStrings(pool, offset)
|
||||
}
|
||||
resolved := m.applyAPIKeyModelAlias(auth, requestedModel)
|
||||
if strings.TrimSpace(resolved) == "" {
|
||||
resolved = requestedModel
|
||||
}
|
||||
return []string{resolved}
|
||||
}
|
||||
|
||||
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for range ch {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
|
||||
if ch == nil {
|
||||
return nil, true, nil
|
||||
}
|
||||
buffered := make([]cliproxyexecutor.StreamChunk, 0, 1)
|
||||
for {
|
||||
var (
|
||||
chunk cliproxyexecutor.StreamChunk
|
||||
ok bool
|
||||
)
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case chunk, ok = <-ch:
|
||||
}
|
||||
} else {
|
||||
chunk, ok = <-ch
|
||||
}
|
||||
if !ok {
|
||||
return buffered, true, nil
|
||||
}
|
||||
if chunk.Err != nil {
|
||||
return nil, false, chunk.Err
|
||||
}
|
||||
buffered = append(buffered, chunk)
|
||||
if len(chunk.Payload) > 0 {
|
||||
return buffered, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
emit := func(chunk cliproxyexecutor.StreamChunk) bool {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
return false
|
||||
}
|
||||
if ctx == nil {
|
||||
out <- chunk
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
forward = false
|
||||
return false
|
||||
case out <- chunk:
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, chunk := range buffered {
|
||||
if ok := emit(chunk); !ok {
|
||||
discardStreamChunks(remaining)
|
||||
return
|
||||
}
|
||||
}
|
||||
for chunk := range remaining {
|
||||
if ok := emit(chunk); !ok {
|
||||
discardStreamChunks(remaining)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
|
||||
if executor == nil {
|
||||
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
execModels := m.prepareExecutionModels(auth, routeModel)
|
||||
var lastErr error
|
||||
for idx, execModel := range execModels {
|
||||
execReq := req
|
||||
execReq.Model = execModel
|
||||
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(ctx, result)
|
||||
if isRequestInvalidError(errStream) {
|
||||
return nil, errStream
|
||||
}
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
|
||||
buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
|
||||
if bootstrapErr != nil {
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
return nil, errCtx
|
||||
}
|
||||
if isRequestInvalidError(bootstrapErr) {
|
||||
rerr := &Error{Message: bootstrapErr.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
return nil, bootstrapErr
|
||||
}
|
||||
if idx < len(execModels)-1 {
|
||||
rerr := &Error{Message: bootstrapErr.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||
m.MarkResult(ctx, result)
|
||||
discardStreamChunks(streamResult.Chunks)
|
||||
lastErr = bootstrapErr
|
||||
continue
|
||||
}
|
||||
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
|
||||
close(errCh)
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
|
||||
}
|
||||
|
||||
if closed && len(buffered) == 0 {
|
||||
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
|
||||
m.MarkResult(ctx, result)
|
||||
if idx < len(execModels)-1 {
|
||||
lastErr = emptyErr
|
||||
continue
|
||||
}
|
||||
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
|
||||
close(errCh)
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
|
||||
}
|
||||
|
||||
remaining := streamResult.Chunks
|
||||
if closed {
|
||||
closedCh := make(chan cliproxyexecutor.StreamChunk)
|
||||
close(closedCh)
|
||||
remaining = closedCh
|
||||
}
|
||||
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
|
||||
@@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
auth.ID = uuid.NewString()
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
authClone := auth.Clone()
|
||||
m.mu.Lock()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.auths[auth.ID] = authClone
|
||||
m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
}
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
authClone := auth.Clone()
|
||||
m.auths[auth.ID] = authClone
|
||||
m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
// Load resets manager state from the backing store.
|
||||
func (m *Manager) Load(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.store == nil {
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
items, err := m.store.List(ctx)
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
m.auths = make(map[string]*Auth, len(items))
|
||||
@@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
m.rebuildAPIKeyModelAliasLocked(cfg)
|
||||
m.mu.Unlock()
|
||||
m.syncScheduler()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -634,32 +1005,42 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
|
||||
models := m.prepareExecutionModels(auth, routeModel)
|
||||
var authErr error
|
||||
for _, upstreamModel := range models {
|
||||
execReq := req
|
||||
execReq.Model = upstreamModel
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
}
|
||||
authErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
return resp, nil
|
||||
}
|
||||
if authErr != nil {
|
||||
if isRequestInvalidError(authErr) {
|
||||
return cliproxyexecutor.Response{}, authErr
|
||||
}
|
||||
lastErr = errExec
|
||||
lastErr = authErr
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -696,32 +1077,42 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
|
||||
models := m.prepareExecutionModels(auth, routeModel)
|
||||
var authErr error
|
||||
for _, upstreamModel := range models {
|
||||
execReq := req
|
||||
execReq.Model = upstreamModel
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.hook.OnResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
}
|
||||
authErr = errExec
|
||||
continue
|
||||
}
|
||||
m.hook.OnResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
return resp, nil
|
||||
}
|
||||
if authErr != nil {
|
||||
if isRequestInvalidError(authErr) {
|
||||
return cliproxyexecutor.Response{}, authErr
|
||||
}
|
||||
lastErr = errExec
|
||||
lastErr = authErr
|
||||
continue
|
||||
}
|
||||
m.hook.OnResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errStream) {
|
||||
return nil, errStream
|
||||
}
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: streamResult.Headers,
|
||||
Chunks: out,
|
||||
}, nil
|
||||
return streamResult, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
suspendReason := ""
|
||||
clearModelQuota := false
|
||||
setModelQuota := false
|
||||
var authSnapshot *Auth
|
||||
|
||||
m.mu.Lock()
|
||||
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
||||
@@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
}
|
||||
|
||||
_ = m.persist(ctx, auth)
|
||||
authSnapshot = auth.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if m.scheduler != nil && authSnapshot != nil {
|
||||
m.scheduler.upsertAuth(authSnapshot)
|
||||
}
|
||||
|
||||
if clearModelQuota && result.Model != "" {
|
||||
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
|
||||
@@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int {
|
||||
}
|
||||
|
||||
// isRequestInvalidError returns true if the error represents a client request
|
||||
// error that should not be retried. Specifically, it checks for 400 Bad Request
|
||||
// with "invalid_request_error" in the message, indicating the request itself is
|
||||
// malformed and switching to a different auth will not help.
|
||||
// error that should not be retried. Specifically, it treats 400 responses with
|
||||
// "invalid_request_error" and all 422 responses as request-shape failures,
|
||||
// where switching auths or pooled upstream models will not help.
|
||||
func isRequestInvalidError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromError(err)
|
||||
if status != http.StatusBadRequest {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return strings.Contains(err.Error(), "invalid_request_error")
|
||||
case http.StatusUnprocessableEntity:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), "invalid_request_error")
|
||||
}
|
||||
|
||||
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
||||
@@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
func (m *Manager) useSchedulerFastPath() bool {
|
||||
if m == nil || m.scheduler == nil {
|
||||
return false
|
||||
}
|
||||
return isBuiltInSelector(m.selector)
|
||||
}
|
||||
|
||||
func shouldRetrySchedulerPick(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var cooldownErr *modelCooldownError
|
||||
if errors.As(err, &cooldownErr) {
|
||||
return true
|
||||
}
|
||||
var authErr *Error
|
||||
if !errors.As(err, &authErr) || authErr == nil {
|
||||
return false
|
||||
}
|
||||
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
m.mu.RLock()
|
||||
@@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
if !m.useSchedulerFastPath() {
|
||||
return m.pickNextLegacy(ctx, provider, model, opts, tried)
|
||||
}
|
||||
executor, okExecutor := m.Executor(provider)
|
||||
if !okExecutor {
|
||||
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
|
||||
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
|
||||
m.syncScheduler()
|
||||
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
|
||||
}
|
||||
if errPick != nil {
|
||||
return nil, nil, errPick
|
||||
}
|
||||
if selected == nil {
|
||||
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||
}
|
||||
authCopy := selected.Clone()
|
||||
if !selected.indexAssigned {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||
current.EnsureIndex()
|
||||
authCopy = current.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
@@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
||||
return authCopy, executor, providerKey, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
if !m.useSchedulerFastPath() {
|
||||
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
|
||||
}
|
||||
|
||||
eligibleProviders := make([]string, 0, len(providers))
|
||||
seenProviders := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
providerKey := strings.TrimSpace(strings.ToLower(provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, seen := seenProviders[providerKey]; seen {
|
||||
continue
|
||||
}
|
||||
if _, okExecutor := m.Executor(providerKey); !okExecutor {
|
||||
continue
|
||||
}
|
||||
seenProviders[providerKey] = struct{}{}
|
||||
eligibleProviders = append(eligibleProviders, providerKey)
|
||||
}
|
||||
if len(eligibleProviders) == 0 {
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
|
||||
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
|
||||
m.syncScheduler()
|
||||
selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
|
||||
}
|
||||
if errPick != nil {
|
||||
return nil, nil, "", errPick
|
||||
}
|
||||
if selected == nil {
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||
}
|
||||
executor, okExecutor := m.Executor(providerKey)
|
||||
if !okExecutor {
|
||||
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
authCopy := selected.Clone()
|
||||
if !selected.indexAssigned {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||
current.EnsureIndex()
|
||||
authCopy = current.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return authCopy, executor, providerKey, nil
|
||||
}
|
||||
|
||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||
if m.store == nil || auth == nil {
|
||||
return nil
|
||||
@@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
current.LastError = &Error{Message: err.Error()}
|
||||
m.auths[id] = current
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(current.Clone())
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return
|
||||
|
||||
163
sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
Normal file
163
sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type schedulerProviderTestExecutor struct {
|
||||
provider string
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
|
||||
|
||||
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
prime func(*Manager, *Auth) error
|
||||
}{
|
||||
{
|
||||
name: "register",
|
||||
prime: func(manager *Manager, auth *Auth) error {
|
||||
_, errRegister := manager.Register(ctx, auth)
|
||||
return errRegister
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update",
|
||||
prime: func(manager *Manager, auth *Auth) error {
|
||||
_, errRegister := manager.Register(ctx, auth)
|
||||
if errRegister != nil {
|
||||
return errRegister
|
||||
}
|
||||
updated := auth.Clone()
|
||||
updated.Metadata = map[string]any{"updated": true}
|
||||
_, errUpdate := manager.Update(ctx, updated)
|
||||
return errUpdate
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
auth := &Auth{
|
||||
ID: "refresh-entry-" + testCase.name,
|
||||
Provider: "gemini",
|
||||
}
|
||||
if errPrime := testCase.prime(manager, auth); errPrime != nil {
|
||||
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
|
||||
}
|
||||
|
||||
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||
var authErr *Error
|
||||
if !errors.As(errPick, &authErr) || authErr == nil {
|
||||
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
|
||||
}
|
||||
if authErr.Code != "auth_not_found" {
|
||||
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
|
||||
}
|
||||
|
||||
manager.RefreshSchedulerEntry(auth.ID)
|
||||
|
||||
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() after refresh error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != auth.ID {
|
||||
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
|
||||
|
||||
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
|
||||
|
||||
oldAuth := &Auth{
|
||||
ID: "cooldown-stale-old",
|
||||
Provider: "gemini",
|
||||
}
|
||||
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
|
||||
t.Fatalf("register old auth: %v", errRegister)
|
||||
}
|
||||
|
||||
manager.MarkResult(ctx, Result{
|
||||
AuthID: oldAuth.ID,
|
||||
Provider: "gemini",
|
||||
Model: "scheduler-cooldown-rebuild-model",
|
||||
Success: false,
|
||||
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
|
||||
})
|
||||
|
||||
newAuth := &Auth{
|
||||
ID: "cooldown-stale-new",
|
||||
Provider: "gemini",
|
||||
}
|
||||
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
|
||||
t.Fatalf("register new auth: %v", errRegister)
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(newAuth.ID)
|
||||
})
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||
var cooldownErr *modelCooldownError
|
||||
if !errors.As(errPick, &cooldownErr) {
|
||||
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
|
||||
}
|
||||
|
||||
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNext() error = %v", errPick)
|
||||
}
|
||||
if executor == nil {
|
||||
t.Fatal("pickNext() executor = nil")
|
||||
}
|
||||
if got == nil || got.ID != newAuth.ID {
|
||||
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
|
||||
}
|
||||
}
|
||||
@@ -80,54 +80,98 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string
|
||||
return upstreamModel
|
||||
}
|
||||
|
||||
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
|
||||
func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
return thinking.SuffixResult{}, nil
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
base := requestResult.ModelName
|
||||
if base == "" {
|
||||
base = requestedModel
|
||||
}
|
||||
candidates := []string{base}
|
||||
if base != requestedModel {
|
||||
candidates = append(candidates, requestedModel)
|
||||
}
|
||||
return requestResult, candidates
|
||||
}
|
||||
|
||||
preserveSuffix := func(resolved string) string {
|
||||
resolved = strings.TrimSpace(resolved)
|
||||
if resolved == "" {
|
||||
return ""
|
||||
}
|
||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
||||
return resolved
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
|
||||
resolved = strings.TrimSpace(resolved)
|
||||
if resolved == "" {
|
||||
return ""
|
||||
}
|
||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
||||
return resolved
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestResult, candidates := modelAliasLookupCandidates(requestedModel)
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]string, 0)
|
||||
seen := make(map[string]struct{})
|
||||
for i := range models {
|
||||
name := strings.TrimSpace(models[i].GetName())
|
||||
alias := strings.TrimSpace(models[i].GetAlias())
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) {
|
||||
continue
|
||||
}
|
||||
if alias != "" && strings.EqualFold(alias, candidate) {
|
||||
if name != "" {
|
||||
return preserveSuffix(name)
|
||||
}
|
||||
return preserveSuffix(candidate)
|
||||
resolved := candidate
|
||||
if name != "" {
|
||||
resolved = name
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return preserveSuffix(name)
|
||||
resolved = preserveResolvedModelSuffix(resolved, requestResult)
|
||||
key := strings.ToLower(strings.TrimSpace(resolved))
|
||||
if key == "" {
|
||||
break
|
||||
}
|
||||
if _, exists := seen[key]; exists {
|
||||
break
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, resolved)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
|
||||
for i := range models {
|
||||
name := strings.TrimSpace(models[i].GetName())
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" || name == "" || !strings.EqualFold(name, candidate) {
|
||||
continue
|
||||
}
|
||||
return []string{preserveResolvedModelSuffix(name, requestResult)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
|
||||
resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models)
|
||||
if len(resolved) > 0 {
|
||||
return resolved[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
419
sdk/cliproxy/auth/openai_compat_pool_test.go
Normal file
419
sdk/cliproxy/auth/openai_compat_pool_test.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type openAICompatPoolExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
executeModels []string
|
||||
countModels []string
|
||||
streamModels []string
|
||||
executeErrors map[string]error
|
||||
countErrors map[string]error
|
||||
streamFirstErrors map[string]error
|
||||
streamPayloads map[string][]cliproxyexecutor.StreamChunk
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) Identifier() string { return e.id }
|
||||
|
||||
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
_ = ctx
|
||||
_ = auth
|
||||
_ = opts
|
||||
e.mu.Lock()
|
||||
e.executeModels = append(e.executeModels, req.Model)
|
||||
err := e.executeErrors[req.Model]
|
||||
e.mu.Unlock()
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
_ = ctx
|
||||
_ = auth
|
||||
_ = opts
|
||||
e.mu.Lock()
|
||||
e.streamModels = append(e.streamModels, req.Model)
|
||||
err := e.streamFirstErrors[req.Model]
|
||||
payloadChunks, hasCustomChunks := e.streamPayloads[req.Model]
|
||||
chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...)
|
||||
e.mu.Unlock()
|
||||
ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks)))
|
||||
if err != nil {
|
||||
ch <- cliproxyexecutor.StreamChunk{Err: err}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
|
||||
}
|
||||
if !hasCustomChunks {
|
||||
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)}
|
||||
} else {
|
||||
for _, chunk := range chunks {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
close(ch)
|
||||
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
_ = ctx
|
||||
_ = auth
|
||||
_ = opts
|
||||
e.mu.Lock()
|
||||
e.countModels = append(e.countModels, req.Model)
|
||||
err := e.countErrors[req.Model]
|
||||
e.mu.Unlock()
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||
_ = ctx
|
||||
_ = auth
|
||||
_ = req
|
||||
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.executeModels))
|
||||
copy(out, e.executeModels)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) CountModels() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.countModels))
|
||||
copy(out, e.countModels)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *openAICompatPoolExecutor) StreamModels() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.streamModels))
|
||||
copy(out, e.streamModels)
|
||||
return out
|
||||
}
|
||||
|
||||
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
|
||||
t.Helper()
|
||||
cfg := &internalconfig.Config{
|
||||
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
|
||||
Name: "pool",
|
||||
Models: models,
|
||||
}},
|
||||
}
|
||||
m := NewManager(nil, nil, nil)
|
||||
m.SetConfig(cfg)
|
||||
if executor == nil {
|
||||
executor = &openAICompatPoolExecutor{id: "pool"}
|
||||
}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
auth := &Auth{
|
||||
ID: "pool-auth-" + t.Name(),
|
||||
Provider: "pool",
|
||||
Status: StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"api_key": "test-key",
|
||||
"compat_name": "pool",
|
||||
"provider_key": "pool",
|
||||
},
|
||||
}
|
||||
if _, err := m.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(auth.ID)
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
_, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err == nil || err.Error() != invalidErr.Error() {
|
||||
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
got := executor.CountModels()
|
||||
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
t.Fatalf("count calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
|
||||
models := []modelAliasEntry{
|
||||
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
|
||||
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
|
||||
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
|
||||
}
|
||||
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
|
||||
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute %d: %v", i, err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatalf("execute %d returned empty payload", i)
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
_, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err == nil || err.Error() != invalidErr.Error() {
|
||||
t.Fatalf("execute error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
t.Fatalf("execute calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute: %v", err)
|
||||
}
|
||||
if string(resp.Payload) != "glm-5" {
|
||||
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||
}
|
||||
got := executor.ExecuteModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
|
||||
"qwen3.5-plus": {},
|
||||
},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute stream: %v", err)
|
||||
}
|
||||
var payload []byte
|
||||
for chunk := range streamResult.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
payload = append(payload, chunk.Payload...)
|
||||
}
|
||||
if string(payload) != "glm-5" {
|
||||
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
|
||||
}
|
||||
got := executor.StreamModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute stream: %v", err)
|
||||
}
|
||||
var payload []byte
|
||||
for chunk := range streamResult.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
payload = append(payload, chunk.Payload...)
|
||||
}
|
||||
if string(payload) != "glm-5" {
|
||||
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
|
||||
}
|
||||
got := executor.StreamModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
|
||||
t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
_, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err == nil || err.Error() != invalidErr.Error() {
|
||||
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
got := executor.StreamModels()
|
||||
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
t.Fatalf("stream calls = %v, want only first invalid model", got)
|
||||
}
|
||||
}
|
||||
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute count %d: %v", i, err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatalf("execute count %d returned empty payload", i)
|
||||
}
|
||||
}
|
||||
|
||||
got := executor.CountModels()
|
||||
want := []string{"qwen3.5-plus", "glm-5"}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
|
||||
alias := "claude-opus-4.66"
|
||||
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||
executor := &openAICompatPoolExecutor{
|
||||
id: "pool",
|
||||
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||
}
|
||||
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||
{Name: "qwen3.5-plus", Alias: alias},
|
||||
{Name: "glm-5", Alias: alias},
|
||||
}, executor)
|
||||
|
||||
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid request error")
|
||||
}
|
||||
if err != invalidErr {
|
||||
t.Fatalf("error = %v, want %v", err, invalidErr)
|
||||
}
|
||||
if streamResult != nil {
|
||||
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
|
||||
}
|
||||
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||
t.Fatalf("stream calls = %v, want only first upstream model", got)
|
||||
}
|
||||
}
|
||||
904
sdk/cliproxy/auth/scheduler.go
Normal file
904
sdk/cliproxy/auth/scheduler.go
Normal file
@@ -0,0 +1,904 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
|
||||
type schedulerStrategy int
|
||||
|
||||
const (
|
||||
schedulerStrategyCustom schedulerStrategy = iota
|
||||
schedulerStrategyRoundRobin
|
||||
schedulerStrategyFillFirst
|
||||
)
|
||||
|
||||
// scheduledState describes how an auth currently participates in a model shard.
|
||||
type scheduledState int
|
||||
|
||||
const (
|
||||
scheduledStateReady scheduledState = iota
|
||||
scheduledStateCooldown
|
||||
scheduledStateBlocked
|
||||
scheduledStateDisabled
|
||||
)
|
||||
|
||||
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
|
||||
type authScheduler struct {
|
||||
mu sync.Mutex
|
||||
strategy schedulerStrategy
|
||||
providers map[string]*providerScheduler
|
||||
authProviders map[string]string
|
||||
mixedCursors map[string]int
|
||||
}
|
||||
|
||||
// providerScheduler stores auth metadata and model shards for a single provider.
|
||||
type providerScheduler struct {
|
||||
providerKey string
|
||||
auths map[string]*scheduledAuthMeta
|
||||
modelShards map[string]*modelScheduler
|
||||
}
|
||||
|
||||
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
|
||||
type scheduledAuthMeta struct {
|
||||
auth *Auth
|
||||
providerKey string
|
||||
priority int
|
||||
virtualParent string
|
||||
websocketEnabled bool
|
||||
supportedModelSet map[string]struct{}
|
||||
}
|
||||
|
||||
// modelScheduler tracks ready and blocked auths for one provider/model combination.
|
||||
type modelScheduler struct {
|
||||
modelKey string
|
||||
entries map[string]*scheduledAuth
|
||||
priorityOrder []int
|
||||
readyByPriority map[int]*readyBucket
|
||||
blocked cooldownQueue
|
||||
}
|
||||
|
||||
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
|
||||
type scheduledAuth struct {
|
||||
meta *scheduledAuthMeta
|
||||
auth *Auth
|
||||
state scheduledState
|
||||
nextRetryAt time.Time
|
||||
}
|
||||
|
||||
// readyBucket keeps the ready views for one priority level.
|
||||
type readyBucket struct {
|
||||
all readyView
|
||||
ws readyView
|
||||
}
|
||||
|
||||
// readyView holds the selection order for flat or grouped round-robin traversal.
|
||||
type readyView struct {
|
||||
flat []*scheduledAuth
|
||||
cursor int
|
||||
parentOrder []string
|
||||
parentCursor int
|
||||
children map[string]*childBucket
|
||||
}
|
||||
|
||||
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
|
||||
type childBucket struct {
|
||||
items []*scheduledAuth
|
||||
cursor int
|
||||
}
|
||||
|
||||
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
|
||||
type cooldownQueue []*scheduledAuth
|
||||
|
||||
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
|
||||
func newAuthScheduler(selector Selector) *authScheduler {
|
||||
return &authScheduler{
|
||||
strategy: selectorStrategy(selector),
|
||||
providers: make(map[string]*providerScheduler),
|
||||
authProviders: make(map[string]string),
|
||||
mixedCursors: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
|
||||
func selectorStrategy(selector Selector) schedulerStrategy {
|
||||
switch selector.(type) {
|
||||
case *FillFirstSelector:
|
||||
return schedulerStrategyFillFirst
|
||||
case nil, *RoundRobinSelector:
|
||||
return schedulerStrategyRoundRobin
|
||||
default:
|
||||
return schedulerStrategyCustom
|
||||
}
|
||||
}
|
||||
|
||||
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
|
||||
func (s *authScheduler) setSelector(selector Selector) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.strategy = selectorStrategy(selector)
|
||||
clear(s.mixedCursors)
|
||||
}
|
||||
|
||||
// rebuild recreates the complete scheduler state from an auth snapshot.
|
||||
func (s *authScheduler) rebuild(auths []*Auth) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.providers = make(map[string]*providerScheduler)
|
||||
s.authProviders = make(map[string]string)
|
||||
s.mixedCursors = make(map[string]int)
|
||||
now := time.Now()
|
||||
for _, auth := range auths {
|
||||
s.upsertAuthLocked(auth, now)
|
||||
}
|
||||
}
|
||||
|
||||
// upsertAuth incrementally synchronizes one auth into the scheduler.
|
||||
func (s *authScheduler) upsertAuth(auth *Auth) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.upsertAuthLocked(auth, time.Now())
|
||||
}
|
||||
|
||||
// removeAuth deletes one auth from every scheduler shard that references it.
|
||||
func (s *authScheduler) removeAuth(authID string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.removeAuthLocked(authID)
|
||||
}
|
||||
|
||||
// pickSingle returns the next auth for a single provider/model request using scheduler state.
|
||||
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
|
||||
if s == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||
modelKey := canonicalModelKey(model)
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
if shard == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
predicate := func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil {
|
||||
return false
|
||||
}
|
||||
if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
|
||||
return false
|
||||
}
|
||||
if len(tried) > 0 {
|
||||
if _, ok := tried[entry.auth.ID]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
|
||||
return picked, nil
|
||||
}
|
||||
return nil, shard.unavailableErrorLocked(provider, model, predicate)
|
||||
}
|
||||
|
||||
// pickMixed returns the next auth and provider for a mixed-provider request.
|
||||
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
|
||||
if s == nil {
|
||||
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
normalized := normalizeProviderKeys(providers)
|
||||
if len(normalized) == 0 {
|
||||
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
modelKey := canonicalModelKey(model)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if pinnedAuthID != "" {
|
||||
providerKey := s.authProviders[pinnedAuthID]
|
||||
if providerKey == "" || !containsProvider(normalized, providerKey) {
|
||||
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
predicate := func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
|
||||
return false
|
||||
}
|
||||
if len(tried) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := tried[pinnedAuthID]
|
||||
return !ok
|
||||
}
|
||||
if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
|
||||
}
|
||||
|
||||
predicate := triedPredicate(tried)
|
||||
candidateShards := make([]*modelScheduler, len(normalized))
|
||||
bestPriority := 0
|
||||
hasCandidate := false
|
||||
now := time.Now()
|
||||
for providerIndex, providerKey := range normalized {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, now)
|
||||
candidateShards[providerIndex] = shard
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
|
||||
if !okPriority {
|
||||
continue
|
||||
}
|
||||
if !hasCandidate || priorityReady > bestPriority {
|
||||
bestPriority = priorityReady
|
||||
hasCandidate = true
|
||||
}
|
||||
}
|
||||
if !hasCandidate {
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
if s.strategy == schedulerStrategyFillFirst {
|
||||
for providerIndex, providerKey := range normalized {
|
||||
shard := candidateShards[providerIndex]
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
|
||||
if picked != nil {
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
}
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
|
||||
start := 0
|
||||
if len(normalized) > 0 {
|
||||
start = s.mixedCursors[cursorKey] % len(normalized)
|
||||
}
|
||||
for offset := 0; offset < len(normalized); offset++ {
|
||||
providerIndex := (start + offset) % len(normalized)
|
||||
providerKey := normalized[providerIndex]
|
||||
shard := candidateShards[providerIndex]
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
|
||||
if picked == nil {
|
||||
continue
|
||||
}
|
||||
s.mixedCursors[cursorKey] = providerIndex + 1
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
|
||||
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
|
||||
now := time.Now()
|
||||
total := 0
|
||||
cooldownCount := 0
|
||||
earliest := time.Time{}
|
||||
for _, providerKey := range providers {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
|
||||
total += localTotal
|
||||
cooldownCount += localCooldownCount
|
||||
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
|
||||
earliest = localEarliest
|
||||
}
|
||||
}
|
||||
if total == 0 {
|
||||
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
if cooldownCount == total && !earliest.IsZero() {
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return newModelCooldownError(model, "", resetIn)
|
||||
}
|
||||
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// triedPredicate builds a filter that excludes auths already attempted for the current request.
|
||||
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
|
||||
if len(tried) == 0 {
|
||||
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
|
||||
}
|
||||
return func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := tried[entry.auth.ID]
|
||||
return !ok
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
|
||||
func normalizeProviderKeys(providers []string) []string {
|
||||
seen := make(map[string]struct{}, len(providers))
|
||||
out := make([]string, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[providerKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[providerKey] = struct{}{}
|
||||
out = append(out, providerKey)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// containsProvider reports whether provider is present in the normalized provider list.
|
||||
func containsProvider(providers []string, provider string) bool {
|
||||
for _, candidate := range providers {
|
||||
if candidate == provider {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
|
||||
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
authID := strings.TrimSpace(auth.ID)
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if authID == "" || providerKey == "" || auth.Disabled {
|
||||
s.removeAuthLocked(authID)
|
||||
return
|
||||
}
|
||||
if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
|
||||
if previousState := s.providers[previousProvider]; previousState != nil {
|
||||
previousState.removeAuthLocked(authID)
|
||||
}
|
||||
}
|
||||
meta := buildScheduledAuthMeta(auth)
|
||||
s.authProviders[authID] = providerKey
|
||||
s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
|
||||
}
|
||||
|
||||
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
|
||||
func (s *authScheduler) removeAuthLocked(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
if providerKey := s.authProviders[authID]; providerKey != "" {
|
||||
if providerState := s.providers[providerKey]; providerState != nil {
|
||||
providerState.removeAuthLocked(authID)
|
||||
}
|
||||
delete(s.authProviders, authID)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
|
||||
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
|
||||
if s.providers == nil {
|
||||
s.providers = make(map[string]*providerScheduler)
|
||||
}
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
providerState = &providerScheduler{
|
||||
providerKey: providerKey,
|
||||
auths: make(map[string]*scheduledAuthMeta),
|
||||
modelShards: make(map[string]*modelScheduler),
|
||||
}
|
||||
s.providers[providerKey] = providerState
|
||||
}
|
||||
return providerState
|
||||
}
|
||||
|
||||
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping.
|
||||
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
virtualParent := ""
|
||||
if auth.Attributes != nil {
|
||||
virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
return &scheduledAuthMeta{
|
||||
auth: auth,
|
||||
providerKey: providerKey,
|
||||
priority: authPriority(auth),
|
||||
virtualParent: virtualParent,
|
||||
websocketEnabled: authWebsocketsEnabled(auth),
|
||||
supportedModelSet: supportedModelSetForAuth(auth.ID),
|
||||
}
|
||||
}
|
||||
|
||||
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
|
||||
func supportedModelSetForAuth(authID string) map[string]struct{} {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return nil
|
||||
}
|
||||
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
set := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelKey := canonicalModelKey(model.ID)
|
||||
if modelKey == "" {
|
||||
continue
|
||||
}
|
||||
set[modelKey] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
|
||||
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||
if p == nil || meta == nil || meta.auth == nil {
|
||||
return
|
||||
}
|
||||
p.auths[meta.auth.ID] = meta
|
||||
for modelKey, shard := range p.modelShards {
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
if !meta.supportsModel(modelKey) {
|
||||
shard.removeEntryLocked(meta.auth.ID)
|
||||
continue
|
||||
}
|
||||
shard.upsertEntryLocked(meta, now)
|
||||
}
|
||||
}
|
||||
|
||||
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
|
||||
func (p *providerScheduler) removeAuthLocked(authID string) {
|
||||
if p == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
delete(p.auths, authID)
|
||||
for _, shard := range p.modelShards {
|
||||
if shard != nil {
|
||||
shard.removeEntryLocked(authID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths.
|
||||
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
modelKey = canonicalModelKey(modelKey)
|
||||
if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
|
||||
shard.promoteExpiredLocked(now)
|
||||
return shard
|
||||
}
|
||||
shard := &modelScheduler{
|
||||
modelKey: modelKey,
|
||||
entries: make(map[string]*scheduledAuth),
|
||||
readyByPriority: make(map[int]*readyBucket),
|
||||
}
|
||||
for _, meta := range p.auths {
|
||||
if meta == nil || !meta.supportsModel(modelKey) {
|
||||
continue
|
||||
}
|
||||
shard.upsertEntryLocked(meta, now)
|
||||
}
|
||||
p.modelShards[modelKey] = shard
|
||||
return shard
|
||||
}
|
||||
|
||||
// supportsModel reports whether the auth metadata currently supports modelKey.
|
||||
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
|
||||
modelKey = canonicalModelKey(modelKey)
|
||||
if modelKey == "" {
|
||||
return true
|
||||
}
|
||||
if len(m.supportedModelSet) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := m.supportedModelSet[modelKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
|
||||
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||
if m == nil || meta == nil || meta.auth == nil {
|
||||
return
|
||||
}
|
||||
entry, ok := m.entries[meta.auth.ID]
|
||||
if !ok || entry == nil {
|
||||
entry = &scheduledAuth{}
|
||||
m.entries[meta.auth.ID] = entry
|
||||
}
|
||||
previousState := entry.state
|
||||
previousNextRetryAt := entry.nextRetryAt
|
||||
previousPriority := 0
|
||||
previousParent := ""
|
||||
previousWebsocketEnabled := false
|
||||
if entry.meta != nil {
|
||||
previousPriority = entry.meta.priority
|
||||
previousParent = entry.meta.virtualParent
|
||||
previousWebsocketEnabled = entry.meta.websocketEnabled
|
||||
}
|
||||
|
||||
entry.meta = meta
|
||||
entry.auth = meta.auth
|
||||
entry.nextRetryAt = time.Time{}
|
||||
blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
|
||||
switch {
|
||||
case !blocked:
|
||||
entry.state = scheduledStateReady
|
||||
case reason == blockReasonCooldown:
|
||||
entry.state = scheduledStateCooldown
|
||||
entry.nextRetryAt = next
|
||||
case reason == blockReasonDisabled:
|
||||
entry.state = scheduledStateDisabled
|
||||
default:
|
||||
entry.state = scheduledStateBlocked
|
||||
entry.nextRetryAt = next
|
||||
}
|
||||
|
||||
if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
|
||||
return
|
||||
}
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
|
||||
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
|
||||
func (m *modelScheduler) removeEntryLocked(authID string) {
|
||||
if m == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := m.entries[authID]; !ok {
|
||||
return
|
||||
}
|
||||
delete(m.entries, authID)
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
|
||||
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
|
||||
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
|
||||
if m == nil || len(m.blocked) == 0 {
|
||||
return
|
||||
}
|
||||
changed := false
|
||||
for _, entry := range m.blocked {
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
|
||||
continue
|
||||
}
|
||||
blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
|
||||
switch {
|
||||
case !blocked:
|
||||
entry.state = scheduledStateReady
|
||||
entry.nextRetryAt = time.Time{}
|
||||
case reason == blockReasonCooldown:
|
||||
entry.state = scheduledStateCooldown
|
||||
entry.nextRetryAt = next
|
||||
case reason == blockReasonDisabled:
|
||||
entry.state = scheduledStateDisabled
|
||||
entry.nextRetryAt = time.Time{}
|
||||
default:
|
||||
entry.state = scheduledStateBlocked
|
||||
entry.nextRetryAt = next
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// pickReadyLocked selects the next ready auth from the highest available priority bucket.
|
||||
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.promoteExpiredLocked(time.Now())
|
||||
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
|
||||
if !okPriority {
|
||||
return nil
|
||||
}
|
||||
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
|
||||
}
|
||||
|
||||
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
|
||||
// The caller must ensure expired entries are already promoted when needed.
|
||||
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
|
||||
if m == nil {
|
||||
return 0, false
|
||||
}
|
||||
for _, priority := range m.priorityOrder {
|
||||
bucket := m.readyByPriority[priority]
|
||||
if bucket == nil {
|
||||
continue
|
||||
}
|
||||
view := &bucket.all
|
||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||
view = &bucket.ws
|
||||
}
|
||||
if view.pickFirst(predicate) != nil {
|
||||
return priority, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
|
||||
// The caller must ensure expired entries are already promoted when needed.
|
||||
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
bucket := m.readyByPriority[priority]
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
view := &bucket.all
|
||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||
view = &bucket.ws
|
||||
}
|
||||
var picked *scheduledAuth
|
||||
if strategy == schedulerStrategyFillFirst {
|
||||
picked = view.pickFirst(predicate)
|
||||
} else {
|
||||
picked = view.pickRoundRobin(predicate)
|
||||
}
|
||||
if picked == nil || picked.auth == nil {
|
||||
return nil
|
||||
}
|
||||
return picked.auth
|
||||
}
|
||||
|
||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
||||
now := time.Now()
|
||||
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
|
||||
if total == 0 {
|
||||
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
if cooldownCount == total && !earliest.IsZero() {
|
||||
providerForError := provider
|
||||
if providerForError == "mixed" {
|
||||
providerForError = ""
|
||||
}
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return newModelCooldownError(model, providerForError, resetIn)
|
||||
}
|
||||
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
|
||||
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
|
||||
if m == nil {
|
||||
return 0, 0, time.Time{}
|
||||
}
|
||||
total := 0
|
||||
cooldownCount := 0
|
||||
earliest := time.Time{}
|
||||
for _, entry := range m.entries {
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
if entry.state != scheduledStateCooldown {
|
||||
continue
|
||||
}
|
||||
cooldownCount++
|
||||
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
|
||||
earliest = entry.nextRetryAt
|
||||
}
|
||||
}
|
||||
return total, cooldownCount, earliest
|
||||
}
|
||||
|
||||
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
|
||||
func (m *modelScheduler) rebuildIndexesLocked() {
|
||||
m.readyByPriority = make(map[int]*readyBucket)
|
||||
m.priorityOrder = m.priorityOrder[:0]
|
||||
m.blocked = m.blocked[:0]
|
||||
priorityBuckets := make(map[int][]*scheduledAuth)
|
||||
for _, entry := range m.entries {
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
switch entry.state {
|
||||
case scheduledStateReady:
|
||||
priority := entry.meta.priority
|
||||
priorityBuckets[priority] = append(priorityBuckets[priority], entry)
|
||||
case scheduledStateCooldown, scheduledStateBlocked:
|
||||
m.blocked = append(m.blocked, entry)
|
||||
}
|
||||
}
|
||||
for priority, entries := range priorityBuckets {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].auth.ID < entries[j].auth.ID
|
||||
})
|
||||
m.readyByPriority[priority] = buildReadyBucket(entries)
|
||||
m.priorityOrder = append(m.priorityOrder, priority)
|
||||
}
|
||||
sort.Slice(m.priorityOrder, func(i, j int) bool {
|
||||
return m.priorityOrder[i] > m.priorityOrder[j]
|
||||
})
|
||||
sort.Slice(m.blocked, func(i, j int) bool {
|
||||
left := m.blocked[i]
|
||||
right := m.blocked[j]
|
||||
if left == nil || right == nil {
|
||||
return left != nil
|
||||
}
|
||||
if left.nextRetryAt.Equal(right.nextRetryAt) {
|
||||
return left.auth.ID < right.auth.ID
|
||||
}
|
||||
if left.nextRetryAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
if right.nextRetryAt.IsZero() {
|
||||
return true
|
||||
}
|
||||
return left.nextRetryAt.Before(right.nextRetryAt)
|
||||
})
|
||||
}
|
||||
|
||||
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
|
||||
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
|
||||
bucket := &readyBucket{}
|
||||
bucket.all = buildReadyView(entries)
|
||||
wsEntries := make([]*scheduledAuth, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
|
||||
wsEntries = append(wsEntries, entry)
|
||||
}
|
||||
}
|
||||
bucket.ws = buildReadyView(wsEntries)
|
||||
return bucket
|
||||
}
|
||||
|
||||
// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
|
||||
func buildReadyView(entries []*scheduledAuth) readyView {
|
||||
view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
|
||||
if len(entries) == 0 {
|
||||
return view
|
||||
}
|
||||
groups := make(map[string][]*scheduledAuth)
|
||||
for _, entry := range entries {
|
||||
if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
|
||||
return view
|
||||
}
|
||||
groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
|
||||
}
|
||||
if len(groups) <= 1 {
|
||||
return view
|
||||
}
|
||||
view.children = make(map[string]*childBucket, len(groups))
|
||||
view.parentOrder = make([]string, 0, len(groups))
|
||||
for parent := range groups {
|
||||
view.parentOrder = append(view.parentOrder, parent)
|
||||
}
|
||||
sort.Strings(view.parentOrder)
|
||||
for _, parent := range view.parentOrder {
|
||||
view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
|
||||
}
|
||||
return view
|
||||
}
|
||||
|
||||
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
|
||||
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
for _, entry := range v.flat {
|
||||
if predicate == nil || predicate(entry) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
|
||||
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
if len(v.parentOrder) > 1 && len(v.children) > 0 {
|
||||
return v.pickGroupedRoundRobin(predicate)
|
||||
}
|
||||
if len(v.flat) == 0 {
|
||||
return nil
|
||||
}
|
||||
start := 0
|
||||
if len(v.flat) > 0 {
|
||||
start = v.cursor % len(v.flat)
|
||||
}
|
||||
for offset := 0; offset < len(v.flat); offset++ {
|
||||
index := (start + offset) % len(v.flat)
|
||||
entry := v.flat[index]
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
v.cursor = index + 1
|
||||
return entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickGroupedRoundRobin rotates across parents first and then within the selected parent.
|
||||
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
start := 0
|
||||
if len(v.parentOrder) > 0 {
|
||||
start = v.parentCursor % len(v.parentOrder)
|
||||
}
|
||||
for offset := 0; offset < len(v.parentOrder); offset++ {
|
||||
parentIndex := (start + offset) % len(v.parentOrder)
|
||||
parent := v.parentOrder[parentIndex]
|
||||
child := v.children[parent]
|
||||
if child == nil || len(child.items) == 0 {
|
||||
continue
|
||||
}
|
||||
itemStart := child.cursor % len(child.items)
|
||||
for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
|
||||
itemIndex := (itemStart + itemOffset) % len(child.items)
|
||||
entry := child.items[itemIndex]
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
child.cursor = itemIndex + 1
|
||||
v.parentCursor = parentIndex + 1
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
216
sdk/cliproxy/auth/scheduler_benchmark_test.go
Normal file
216
sdk/cliproxy/auth/scheduler_benchmark_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type schedulerBenchmarkExecutor struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }
|
||||
|
||||
func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
|
||||
b.Helper()
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
providers := []string{"gemini"}
|
||||
manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
|
||||
if mixed {
|
||||
providers = []string{"gemini", "claude"}
|
||||
manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
model := "bench-model"
|
||||
for index := 0; index < total; index++ {
|
||||
provider := providers[0]
|
||||
if mixed && index%2 == 1 {
|
||||
provider = providers[1]
|
||||
}
|
||||
auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
|
||||
if withPriority {
|
||||
priority := "0"
|
||||
if index%2 == 0 {
|
||||
priority = "10"
|
||||
}
|
||||
auth.Attributes = map[string]string{"priority": priority}
|
||||
}
|
||||
_, errRegister := manager.Register(context.Background(), auth)
|
||||
if errRegister != nil {
|
||||
b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
|
||||
}
|
||||
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
|
||||
}
|
||||
manager.syncScheduler()
|
||||
b.Cleanup(func() {
|
||||
for index := 0; index < total; index++ {
|
||||
provider := providers[0]
|
||||
if mixed && index%2 == 1 {
|
||||
provider = providers[1]
|
||||
}
|
||||
reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
|
||||
}
|
||||
})
|
||||
|
||||
return manager, providers, model
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNext500(b *testing.B) {
|
||||
manager, _, model := benchmarkManagerSetup(b, 500, false, false)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil {
|
||||
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNext1000(b *testing.B) {
|
||||
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil {
|
||||
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextPriority500(b *testing.B) {
|
||||
manager, _, model := benchmarkManagerSetup(b, 500, false, true)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil {
|
||||
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
|
||||
manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil {
|
||||
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextMixed500(b *testing.B) {
|
||||
manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil || provider == "" {
|
||||
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
|
||||
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil || provider == "" {
|
||||
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
|
||||
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||
if errPick != nil || auth == nil {
|
||||
b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
|
||||
}
|
||||
manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
|
||||
}
|
||||
}
|
||||
503
sdk/cliproxy/auth/scheduler_test.go
Normal file
503
sdk/cliproxy/auth/scheduler_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type schedulerTestExecutor struct{}
|
||||
|
||||
func (schedulerTestExecutor) Identifier() string { return "test" }
|
||||
|
||||
func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type trackingSelector struct {
|
||||
calls int
|
||||
lastAuthID []string
|
||||
}
|
||||
|
||||
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
s.calls++
|
||||
s.lastAuthID = s.lastAuthID[:0]
|
||||
for _, auth := range auths {
|
||||
s.lastAuthID = append(s.lastAuthID, auth.ID)
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return auths[len(auths)-1], nil
|
||||
}
|
||||
|
||||
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
|
||||
scheduler := newAuthScheduler(selector)
|
||||
scheduler.rebuild(auths)
|
||||
return scheduler
|
||||
}
|
||||
|
||||
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
|
||||
t.Helper()
|
||||
reg := registry.GetGlobalRegistry()
|
||||
for _, authID := range authIDs {
|
||||
reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
for _, authID := range authIDs {
|
||||
reg.UnregisterClient(authID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
|
||||
&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
|
||||
&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
|
||||
)
|
||||
|
||||
want := []string{"high-a", "high-b", "high-a"}
|
||||
for index, wantID := range want {
|
||||
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||
}
|
||||
if got.ID != wantID {
|
||||
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&FillFirstSelector{},
|
||||
&Auth{ID: "b", Provider: "gemini"},
|
||||
&Auth{ID: "a", Provider: "gemini"},
|
||||
&Auth{ID: "c", Provider: "gemini"},
|
||||
)
|
||||
|
||||
for index := 0; index < 3; index++ {
|
||||
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||
}
|
||||
if got.ID != "a" {
|
||||
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := "gemini-2.5-pro"
|
||||
registerSchedulerModels(t, "gemini", model, "cooldown-expired")
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{
|
||||
ID: "cooldown-expired",
|
||||
Provider: "gemini",
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusError,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: time.Now().Add(-1 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() error = %v", errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickSingle() auth = nil")
|
||||
}
|
||||
if got.ID != "cooldown-expired" {
|
||||
t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
|
||||
&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
|
||||
&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
|
||||
&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
|
||||
)
|
||||
|
||||
wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
|
||||
wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
|
||||
for index := range wantIDs {
|
||||
got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
|
||||
t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "codex-http", Provider: "codex"},
|
||||
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
|
||||
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
|
||||
)
|
||||
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
|
||||
for index, wantID := range want {
|
||||
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||
}
|
||||
if got.ID != wantID {
|
||||
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "gemini-a", Provider: "gemini"},
|
||||
&Auth{ID: "gemini-b", Provider: "gemini"},
|
||||
&Auth{ID: "claude-a", Provider: "claude"},
|
||||
)
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
for index := range wantProviders {
|
||||
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickMixed() #%d auth = nil", index)
|
||||
}
|
||||
if provider != wantProviders[index] {
|
||||
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := "gpt-default"
|
||||
registerSchedulerModels(t, "provider-low", model, "low")
|
||||
registerSchedulerModels(t, "provider-high-a", model, "high-a")
|
||||
registerSchedulerModels(t, "provider-high-b", model, "high-b")
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
|
||||
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
|
||||
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
|
||||
)
|
||||
|
||||
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
|
||||
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
|
||||
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
|
||||
for index := range wantProviders {
|
||||
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickMixed() #%d auth = nil", index)
|
||||
}
|
||||
if provider != wantProviders[index] {
|
||||
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||
manager.executors["claude"] = schedulerTestExecutor{}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-b) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
for index := range wantProviders {
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickNextMixed() #%d auth = nil", index)
|
||||
}
|
||||
if provider != wantProviders[index] {
|
||||
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &trackingSelector{}
|
||||
manager := NewManager(nil, selector, nil)
|
||||
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||
manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"}
|
||||
manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"}
|
||||
|
||||
got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNext() error = %v", errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickNext() auth = nil")
|
||||
}
|
||||
if selector.calls != 1 {
|
||||
t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
|
||||
}
|
||||
if len(selector.lastAuthID) != 2 {
|
||||
t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2)
|
||||
}
|
||||
if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] {
|
||||
t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
if manager.scheduler == nil {
|
||||
t.Fatalf("manager.scheduler = nil")
|
||||
}
|
||||
if manager.scheduler.strategy != schedulerStrategyRoundRobin {
|
||||
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
|
||||
}
|
||||
|
||||
manager.SetSelector(&FillFirstSelector{})
|
||||
if manager.scheduler.strategy != schedulerStrategyFillFirst {
|
||||
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(auth-b) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(auth-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("scheduler.pickSingle() error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != "auth-a" {
|
||||
t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got)
|
||||
}
|
||||
|
||||
if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil {
|
||||
t.Fatalf("Update(auth-a) error = %v", errUpdate)
|
||||
}
|
||||
|
||||
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != "auth-b" {
|
||||
t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||
manager.executors["claude"] = schedulerTestExecutor{}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-b) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
for index := range wantProviders {
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickNextMixed() #%d auth = nil", index)
|
||||
}
|
||||
if provider != wantProviders[index] {
|
||||
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.executors["claude"] = schedulerTestExecutor{}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNextMixed() error = %v", errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickNextMixed() auth = nil")
|
||||
}
|
||||
if provider != "claude" {
|
||||
t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
|
||||
}
|
||||
if got.ID != "claude-a" {
|
||||
t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
|
||||
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient("auth-a")
|
||||
reg.UnregisterClient("auth-b")
|
||||
})
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(auth-a) error = %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
|
||||
t.Fatalf("Register(auth-b) error = %v", errRegister)
|
||||
}
|
||||
|
||||
manager.MarkResult(context.Background(), Result{
|
||||
AuthID: "auth-a",
|
||||
Provider: "gemini",
|
||||
Model: "test-model",
|
||||
Success: false,
|
||||
Error: &Error{HTTPStatus: 429, Message: "quota"},
|
||||
})
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != "auth-b" {
|
||||
t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got)
|
||||
}
|
||||
|
||||
manager.MarkResult(context.Background(), Result{
|
||||
AuthID: "auth-a",
|
||||
Provider: "gemini",
|
||||
Model: "test-model",
|
||||
Success: true,
|
||||
})
|
||||
|
||||
seen := make(map[string]struct{}, 2)
|
||||
for index := 0; index < 2; index++ {
|
||||
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index)
|
||||
}
|
||||
seen[got.ID] = struct{}{}
|
||||
}
|
||||
if len(seen) != 2 {
|
||||
t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
|
||||
}
|
||||
}
|
||||
@@ -390,6 +390,27 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
|
||||
// Check metadata for email first (OAuth-style auth)
|
||||
if a.Metadata != nil {
|
||||
if method, ok := a.Metadata["auth_method"].(string); ok {
|
||||
switch strings.ToLower(strings.TrimSpace(method)) {
|
||||
case "oauth":
|
||||
for _, key := range []string{"email", "username", "name"} {
|
||||
if value, okValue := a.Metadata[key].(string); okValue {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return "oauth", trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
case "pat", "personal_access_token":
|
||||
for _, key := range []string{"username", "email", "name", "token_preview"} {
|
||||
if value, okValue := a.Metadata[key].(string); okValue {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return "personal_access_token", trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return "personal_access_token", ""
|
||||
}
|
||||
}
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
email := strings.TrimSpace(v)
|
||||
if email != "" {
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on
|
||||
@@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http.
|
||||
if rt != nil {
|
||||
return rt
|
||||
}
|
||||
// Parse the proxy URL to determine the scheme.
|
||||
proxyURL, errParse := url.Parse(proxyStr)
|
||||
if errParse != nil {
|
||||
log.Errorf("parse proxy URL failed: %v", errParse)
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
return nil
|
||||
}
|
||||
var transport *http.Transport
|
||||
// Handle different proxy schemes.
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
// Configure SOCKS5 proxy with optional authentication.
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth := &proxy.Auth{User: username, Password: password}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil
|
||||
}
|
||||
// Set up a custom transport using the SOCKS5 dialer.
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
// Configure HTTP or HTTPS proxy.
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
} else {
|
||||
log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
if transport == nil {
|
||||
return nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
|
||||
22
sdk/cliproxy/rtprovider_test.go
Normal file
22
sdk/cliproxy/rtprovider_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestRoundTripperForDirectBypassesProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := newDefaultRoundTripperProvider()
|
||||
rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"})
|
||||
transport, ok := rt.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", rt)
|
||||
}
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("expected direct transport to disable proxy function")
|
||||
}
|
||||
}
|
||||
@@ -119,6 +119,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -293,8 +294,6 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
||||
// IMPORTANT: Update coreManager FIRST, before model registration.
|
||||
// This ensures that configuration changes (proxy_url, prefix, etc.) take effect
|
||||
// immediately for API calls, rather than waiting for model registration to complete.
|
||||
// Model registration may involve network calls (e.g., FetchAntigravityModels) that
|
||||
// could timeout if the new proxy_url is unreachable.
|
||||
op := "register"
|
||||
var err error
|
||||
if existing, ok := s.coreManager.GetByID(auth.ID); ok {
|
||||
@@ -323,6 +322,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
||||
// This operation may block on network calls, but the auth configuration
|
||||
// is already effective at this point.
|
||||
s.registerModelsForAuth(auth)
|
||||
|
||||
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
|
||||
// from the now-populated global model registry. Without this, newly added auths
|
||||
// have an empty supportedModelSet (because Register/Update upserts into the
|
||||
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
|
||||
s.coreManager.RefreshSchedulerEntry(auth.ID)
|
||||
}
|
||||
|
||||
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
@@ -438,6 +443,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
case "gitlab":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
@@ -447,6 +454,17 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) {
|
||||
if a == nil || a.ID == "" {
|
||||
return
|
||||
}
|
||||
if len(models) == 0 {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
return
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, models)
|
||||
}
|
||||
|
||||
// rebindExecutors refreshes provider executors so they observe the latest configuration.
|
||||
func (s *Service) rebindExecutors() {
|
||||
if s == nil || s.coreManager == nil {
|
||||
@@ -554,6 +572,44 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.hooks.OnBeforeStart(s.cfg)
|
||||
}
|
||||
|
||||
// Register callback for startup and periodic model catalog refresh.
|
||||
// When remote model definitions change, re-register models for affected providers.
|
||||
// This intentionally rebuilds per-auth model availability from the latest catalog
|
||||
// snapshot instead of preserving prior registry suppression state.
|
||||
registry.SetModelRefreshCallback(func(changedProviders []string) {
|
||||
if s == nil || s.coreManager == nil || len(changedProviders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
providerSet := make(map[string]bool, len(changedProviders))
|
||||
for _, p := range changedProviders {
|
||||
providerSet[strings.ToLower(strings.TrimSpace(p))] = true
|
||||
}
|
||||
|
||||
auths := s.coreManager.List()
|
||||
refreshed := 0
|
||||
for _, item := range auths {
|
||||
if item == nil || item.ID == "" {
|
||||
continue
|
||||
}
|
||||
auth, ok := s.coreManager.GetByID(item.ID)
|
||||
if !ok || auth == nil || auth.Disabled {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if !providerSet[provider] {
|
||||
continue
|
||||
}
|
||||
if s.refreshModelRegistrationForAuth(auth) {
|
||||
refreshed++
|
||||
}
|
||||
}
|
||||
|
||||
if refreshed > 0 {
|
||||
log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders)
|
||||
}
|
||||
})
|
||||
|
||||
s.serverErr = make(chan error, 1)
|
||||
go func() {
|
||||
if errStart := s.server.Start(); errStart != nil {
|
||||
@@ -836,9 +892,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
models = registry.GetAIStudioModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "antigravity":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
models = executor.FetchAntigravityModels(ctx, a, s.cfg)
|
||||
cancel()
|
||||
models = registry.GetAntigravityModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "claude":
|
||||
models = registry.GetClaudeModels()
|
||||
@@ -852,7 +906,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "codex":
|
||||
models = registry.GetOpenAIModels()
|
||||
codexPlanType := ""
|
||||
if a.Attributes != nil {
|
||||
codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
|
||||
}
|
||||
switch strings.ToLower(codexPlanType) {
|
||||
case "pro":
|
||||
models = registry.GetCodexProModels()
|
||||
case "plus":
|
||||
models = registry.GetCodexPlusModels()
|
||||
case "team":
|
||||
models = registry.GetCodexTeamModels()
|
||||
case "free":
|
||||
models = registry.GetCodexFreeModels()
|
||||
default:
|
||||
models = registry.GetCodexProModels()
|
||||
}
|
||||
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
||||
if len(entry.Models) > 0 {
|
||||
models = buildCodexConfigModels(entry)
|
||||
@@ -870,7 +939,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kimi":
|
||||
models = registry.GetKimiModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "github-copilot":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
@@ -882,6 +951,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "kilo":
|
||||
models = executor.FetchKiloModels(context.Background(), a, s.cfg)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "gitlab":
|
||||
models = executor.GitLabModelsFromAuth(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
@@ -949,7 +1021,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if providerKey == "" {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
} else {
|
||||
// Ensure stale registrations are cleared when model list becomes empty.
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -970,16 +1042,60 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
if provider == "antigravity" {
|
||||
s.backfillAntigravityModels(a, models)
|
||||
}
|
||||
s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
return
|
||||
}
|
||||
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
}
|
||||
|
||||
// refreshModelRegistrationForAuth re-applies the latest model registration for
|
||||
// one auth and reconciles any concurrent auth changes that race with the
|
||||
// refresh. Callers are expected to pre-filter provider membership.
|
||||
//
|
||||
// Re-registration is deliberate: registry cooldown/suspension state is treated
|
||||
// as part of the previous registration snapshot and is cleared when the auth is
|
||||
// rebound to the refreshed model catalog.
|
||||
func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
|
||||
if s == nil || s.coreManager == nil || current == nil || current.ID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if !current.Disabled {
|
||||
s.ensureExecutorsForAuth(current)
|
||||
}
|
||||
s.registerModelsForAuth(current)
|
||||
|
||||
latest, ok := s.latestAuthForModelRegistration(current.ID)
|
||||
if !ok || latest.Disabled {
|
||||
GlobalModelRegistry().UnregisterClient(current.ID)
|
||||
s.coreManager.RefreshSchedulerEntry(current.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
// Re-apply the latest auth snapshot so concurrent auth updates cannot leave
|
||||
// stale model registrations behind. This may duplicate registration work when
|
||||
// no auth fields changed, but keeps the refresh path simple and correct.
|
||||
s.ensureExecutorsForAuth(latest)
|
||||
s.registerModelsForAuth(latest)
|
||||
s.coreManager.RefreshSchedulerEntry(current.ID)
|
||||
return true
|
||||
}
|
||||
|
||||
// latestAuthForModelRegistration returns the latest auth snapshot regardless of
|
||||
// provider membership. Callers use this after a registration attempt to restore
|
||||
// whichever state currently owns the client ID in the global registry.
|
||||
func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) {
|
||||
if s == nil || s.coreManager == nil || authID == "" {
|
||||
return nil, false
|
||||
}
|
||||
auth, ok := s.coreManager.GetByID(authID)
|
||||
if !ok || auth == nil || auth.ID == "" {
|
||||
return nil, false
|
||||
}
|
||||
return auth, true
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
|
||||
if auth == nil || s.cfg == nil {
|
||||
return nil
|
||||
@@ -1118,56 +1234,6 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
|
||||
return cfg.OAuthExcludedModels[providerKey]
|
||||
}
|
||||
|
||||
func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
|
||||
if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sourceID := ""
|
||||
if source != nil {
|
||||
sourceID = strings.TrimSpace(source.ID)
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
for _, candidate := range s.coreManager.List() {
|
||||
if candidate == nil || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
candidateID := strings.TrimSpace(candidate.ID)
|
||||
if candidateID == "" || candidateID == sourceID {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
|
||||
continue
|
||||
}
|
||||
if len(reg.GetModelsForClient(candidateID)) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
|
||||
if authKind == "" {
|
||||
if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
excluded := s.oauthExcludedModels("antigravity", authKind)
|
||||
if candidate.Attributes != nil {
|
||||
if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
|
||||
excluded = strings.Split(val, ",")
|
||||
}
|
||||
}
|
||||
|
||||
models := applyExcludedModels(primaryModels, excluded)
|
||||
models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
|
||||
if len(models) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
|
||||
}
|
||||
}
|
||||
|
||||
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
if len(models) == 0 || len(excluded) == 0 {
|
||||
return models
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
|
||||
source := &coreauth.Auth{
|
||||
ID: "ag-backfill-source",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
target := &coreauth.Auth{
|
||||
ID: "ag-backfill-target",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||
t.Fatalf("register source auth: %v", err)
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||
t.Fatalf("register target auth: %v", err)
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: manager,
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
})
|
||||
|
||||
primary := []*ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||
|
||||
service.backfillAntigravityModels(source, primary)
|
||||
|
||||
got := reg.GetModelsForClient(target.ID)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
|
||||
}
|
||||
|
||||
ids := make(map[string]struct{}, len(got))
|
||||
for _, model := range got {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
|
||||
}
|
||||
if _, ok := ids["claude-sonnet-4-5"]; !ok {
|
||||
t.Fatal("expected backfilled model claude-sonnet-4-5")
|
||||
}
|
||||
if _, ok := ids["gemini-2.5-pro"]; !ok {
|
||||
t.Fatal("expected backfilled model gemini-2.5-pro")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
|
||||
source := &coreauth.Auth{
|
||||
ID: "ag-backfill-source-excluded",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
target := &coreauth.Auth{
|
||||
ID: "ag-backfill-target-excluded",
|
||||
Provider: "antigravity",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"excluded_models": "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
if _, err := manager.Register(context.Background(), source); err != nil {
|
||||
t.Fatalf("register source auth: %v", err)
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), target); err != nil {
|
||||
t.Fatalf("register target auth: %v", err)
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: manager,
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(source.ID)
|
||||
reg.UnregisterClient(target.ID)
|
||||
})
|
||||
|
||||
primary := []*ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
reg.RegisterClient(source.ID, "antigravity", primary)
|
||||
|
||||
service.backfillAntigravityModels(source, primary)
|
||||
|
||||
got := reg.GetModelsForClient(target.ID)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 model after exclusion, got %d", len(got))
|
||||
}
|
||||
if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
|
||||
t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
|
||||
}
|
||||
}
|
||||
48
sdk/cliproxy/service_gitlab_models_test.go
Normal file
48
sdk/cliproxy/service_gitlab_models_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestRegisterModelsForAuth_GitLabUsesDiscoveredModels(t *testing.T) {
|
||||
service := &Service{cfg: &config.Config{}}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
models := reg.GetModelsForClient(auth.ID)
|
||||
if len(models) < 2 {
|
||||
t.Fatalf("expected stable alias and discovered model, got %d entries", len(models))
|
||||
}
|
||||
|
||||
seenAlias := false
|
||||
seenDiscovered := false
|
||||
for _, model := range models {
|
||||
switch model.ID {
|
||||
case "gitlab-duo":
|
||||
seenAlias = true
|
||||
case "claude-sonnet-4-5":
|
||||
seenDiscovered = true
|
||||
}
|
||||
}
|
||||
if !seenAlias || !seenDiscovered {
|
||||
t.Fatalf("expected gitlab-duo and discovered model, got %+v", models)
|
||||
}
|
||||
}
|
||||
139
sdk/proxyutil/proxy.go
Normal file
139
sdk/proxyutil/proxy.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Mode describes how a proxy setting should be interpreted.
|
||||
type Mode int
|
||||
|
||||
const (
|
||||
// ModeInherit means no explicit proxy behavior was configured.
|
||||
ModeInherit Mode = iota
|
||||
// ModeDirect means outbound requests must bypass proxies explicitly.
|
||||
ModeDirect
|
||||
// ModeProxy means a concrete proxy URL was configured.
|
||||
ModeProxy
|
||||
// ModeInvalid means the proxy setting is present but malformed or unsupported.
|
||||
ModeInvalid
|
||||
)
|
||||
|
||||
// Setting is the normalized interpretation of a proxy configuration value.
|
||||
type Setting struct {
|
||||
Raw string
|
||||
Mode Mode
|
||||
URL *url.URL
|
||||
}
|
||||
|
||||
// Parse normalizes a proxy configuration value into inherit, direct, or proxy modes.
|
||||
func Parse(raw string) (Setting, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
setting := Setting{Raw: trimmed}
|
||||
|
||||
if trimmed == "" {
|
||||
setting.Mode = ModeInherit
|
||||
return setting, nil
|
||||
}
|
||||
|
||||
if strings.EqualFold(trimmed, "direct") || strings.EqualFold(trimmed, "none") {
|
||||
setting.Mode = ModeDirect
|
||||
return setting, nil
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(trimmed)
|
||||
if errParse != nil {
|
||||
setting.Mode = ModeInvalid
|
||||
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
setting.Mode = ModeInvalid
|
||||
return setting, fmt.Errorf("proxy URL missing scheme/host")
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "socks5", "http", "https":
|
||||
setting.Mode = ModeProxy
|
||||
setting.URL = parsedURL
|
||||
return setting, nil
|
||||
default:
|
||||
setting.Mode = ModeInvalid
|
||||
return setting, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// NewDirectTransport returns a transport that bypasses environment proxies.
|
||||
func NewDirectTransport() *http.Transport {
|
||||
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
|
||||
clone := transport.Clone()
|
||||
clone.Proxy = nil
|
||||
return clone
|
||||
}
|
||||
return &http.Transport{Proxy: nil}
|
||||
}
|
||||
|
||||
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
|
||||
func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
|
||||
setting, errParse := Parse(raw)
|
||||
if errParse != nil {
|
||||
return nil, setting.Mode, errParse
|
||||
}
|
||||
|
||||
switch setting.Mode {
|
||||
case ModeInherit:
|
||||
return nil, setting.Mode, nil
|
||||
case ModeDirect:
|
||||
return NewDirectTransport(), setting.Mode, nil
|
||||
case ModeProxy:
|
||||
if setting.URL.Scheme == "socks5" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if setting.URL.User != nil {
|
||||
username := setting.URL.User.Username()
|
||||
password, _ := setting.URL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}, setting.Mode, nil
|
||||
}
|
||||
return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
|
||||
default:
|
||||
return nil, setting.Mode, nil
|
||||
}
|
||||
}
|
||||
|
||||
// BuildDialer constructs a proxy dialer for settings that operate at the connection layer.
|
||||
func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
|
||||
setting, errParse := Parse(raw)
|
||||
if errParse != nil {
|
||||
return nil, setting.Mode, errParse
|
||||
}
|
||||
|
||||
switch setting.Mode {
|
||||
case ModeInherit:
|
||||
return nil, setting.Mode, nil
|
||||
case ModeDirect:
|
||||
return proxy.Direct, setting.Mode, nil
|
||||
case ModeProxy:
|
||||
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
|
||||
if errDialer != nil {
|
||||
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
|
||||
}
|
||||
return dialer, setting.Mode, nil
|
||||
default:
|
||||
return nil, setting.Mode, nil
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user