mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 23:33:24 +00:00
Compare commits
164 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f81046730 | ||
|
|
0687472d01 | ||
|
|
7739738fb3 | ||
|
|
99d1ce247b | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 | ||
|
|
182b31963a | ||
|
|
4f48e5254a | ||
|
|
15dd5db1d7 | ||
|
|
424711b718 | ||
|
|
91a2b1f0b4 | ||
|
|
2b134fc378 | ||
|
|
b9153719b0 | ||
|
|
631e5c8331 | ||
|
|
e9c60a0a67 | ||
|
|
98a1bb5a7f | ||
|
|
ca90487a8c | ||
|
|
1042489f85 | ||
|
|
38277c1ea6 | ||
|
|
ee0c24628f | ||
|
|
3a18f6fcca | ||
|
|
099e734a02 | ||
|
|
a52da26b5d | ||
|
|
522a68a4ea | ||
|
|
a02eda54d0 | ||
|
|
97ef633c57 | ||
|
|
dae8463ba1 | ||
|
|
7c1299922e | ||
|
|
ddcf1f279d | ||
|
|
7e6bb8fdc5 | ||
|
|
9cee8ef87b | ||
|
|
93fb841bcb | ||
|
|
0c05131aeb | ||
|
|
5ebc58fab4 | ||
|
|
2b609dd891 | ||
|
|
a8cbc68c3e | ||
|
|
11a795a01c | ||
|
|
89c428216e | ||
|
|
2695a99623 | ||
|
|
242aecd924 | ||
|
|
ce8cc1ba33 | ||
|
|
ad5253bd2b | ||
|
|
97fdd2e088 | ||
|
|
9397f7049f | ||
|
|
a14d19b92c | ||
|
|
8ae0c05ea6 | ||
|
|
8822f20d17 | ||
|
|
553d6f50ea | ||
|
|
f0e5a5a367 | ||
|
|
f6dfea9357 | ||
|
|
cc8dc7f62c | ||
|
|
a3846ea513 | ||
|
|
8d44be858e | ||
|
|
0e6bb076e9 | ||
|
|
ac135fc7cb | ||
|
|
4e1d09809d | ||
|
|
9e855f8100 | ||
|
|
25680a8259 | ||
|
|
13c93e8cfd | ||
|
|
88aa1b9fd1 | ||
|
|
352cb98ff0 | ||
|
|
ac95e92829 | ||
|
|
8526c2da25 | ||
|
|
68a6cabf8b | ||
|
|
ac0e387da1 | ||
|
|
7fe1d102cb | ||
|
|
5850492a93 | ||
|
|
fdbd4041ca | ||
|
|
ebef1fae2a | ||
|
|
c51851689b | ||
|
|
419bf784ab | ||
|
|
4bbeb92e9a | ||
|
|
b436dad8bc | ||
|
|
6ae15d6c44 | ||
|
|
0468bde0d6 | ||
|
|
1d7329e797 | ||
|
|
48ffc4dee7 | ||
|
|
7ebd8f0c44 | ||
|
|
b680c146c1 | ||
|
|
7d6660d181 | ||
|
|
d8e3d4e2b6 | ||
|
|
d26ad8224d | ||
|
|
5c84d69d42 | ||
|
|
527e4b7f26 | ||
|
|
b48485b42b | ||
|
|
79009bb3d4 | ||
|
|
26fc611f86 | ||
|
|
b43743d4f1 | ||
|
|
179e5434b1 | ||
|
|
9f95b31158 | ||
|
|
5da07eae4c | ||
|
|
835ae178d4 | ||
|
|
c80ab8bf0d | ||
|
|
ce87714ef1 | ||
|
|
0452b869e8 | ||
|
|
d2e5857b82 | ||
|
|
f9b005f21f | ||
|
|
532107b4fa | ||
|
|
c44793789b | ||
|
|
4e99525279 | ||
|
|
7547d1d0b3 | ||
|
|
68934942d0 | ||
|
|
09fec34e1c | ||
|
|
9229708b6c | ||
|
|
914db94e79 | ||
|
|
660bd7eff5 | ||
|
|
b907d21851 | ||
|
|
dd44413ba5 | ||
|
|
10fa0f2062 | ||
|
|
d6cc976d1f | ||
|
|
8aa2cce8c5 | ||
|
|
bf9b2c49df | ||
|
|
77b42c6165 | ||
|
|
446150a747 | ||
|
|
1cbc4834e1 | ||
|
|
30338ecec4 | ||
|
|
9a37defed3 | ||
|
|
c83a057996 | ||
|
|
a8a5d03c33 | ||
|
|
76aa917882 | ||
|
|
6ac9b31e4e | ||
|
|
0ad3e8457f | ||
|
|
444a47ae63 | ||
|
|
725f4fdff4 | ||
|
|
c23e46f45d | ||
|
|
b148820c35 | ||
|
|
134f41496d | ||
|
|
c5838dd58d | ||
|
|
b6ca5ef7ce | ||
|
|
1ae994b4aa | ||
|
|
84e9793e61 | ||
|
|
32e64dacfd | ||
|
|
cc1d8f6629 | ||
|
|
5446cd2b02 | ||
|
|
8de0885b7d | ||
|
|
16243f18fd | ||
|
|
a6ce5f36e6 | ||
|
|
e73cf42e28 | ||
|
|
b45343e812 | ||
|
|
8599b1560e | ||
|
|
8bde8c37c0 | ||
|
|
68dd2bfe82 | ||
|
|
2baf35b3ef | ||
|
|
846e75b893 | ||
|
|
fc0257d6d9 | ||
|
|
f3c164d345 | ||
|
|
4040b1e766 | ||
|
|
b7588428c5 | ||
|
|
8f97a5f77c | ||
|
|
2a4d3e60f3 | ||
|
|
8b5af2ab84 | ||
|
|
d887716ebd | ||
|
|
5dc1848466 | ||
|
|
9491517b26 | ||
|
|
9370b5bd04 | ||
|
|
abb51a0d93 | ||
|
|
c8d809131b | ||
|
|
dd71c73a9f | ||
|
|
2615f489d6 |
@@ -31,6 +31,7 @@ bin/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -44,6 +44,7 @@ GEMINI.md
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
126
README.md
126
README.md
@@ -8,132 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
|
||||
|
||||
The Plus release stays in lockstep with the mainline features.
|
||||
|
||||
## Differences from the Mainline
|
||||
|
||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||
|
||||
## New Features (Plus Enhanced)
|
||||
|
||||
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
|
||||
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
|
||||
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
|
||||
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
|
||||
- **Device Fingerprint**: Device fingerprint generation for enhanced security
|
||||
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
|
||||
- **Usage Checker**: Real-time usage monitoring and quota management
|
||||
- **Model Converter**: Unified model name conversion across providers
|
||||
- **UTF-8 Stream Processing**: Improved streaming response handling
|
||||
|
||||
## Kiro Authentication
|
||||
|
||||
### CLI Login
|
||||
|
||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
|
||||
**AWS Builder ID** (recommended):
|
||||
|
||||
```bash
|
||||
# Device code flow
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# Authorization code flow
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**Import token from Kiro IDE:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
To get a token from Kiro IDE:
|
||||
|
||||
1. Open Kiro IDE and login with Google (or GitHub)
|
||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
||||
3. Run: `./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# Specify region
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**Additional flags:**
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
||||
|
||||
### Web-based OAuth Login
|
||||
|
||||
Access the Kiro OAuth web interface at:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
||||
- AWS Builder ID login
|
||||
- AWS Identity Center (IDC) login
|
||||
- Token import from Kiro IDE
|
||||
|
||||
## Quick Deployment with Docker
|
||||
|
||||
### One-Command Deployment
|
||||
|
||||
```bash
|
||||
# Create deployment directory
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# Create docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# Download example config
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# Pull and start
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Edit `config.yaml` before starting:
|
||||
|
||||
```yaml
|
||||
# Basic configuration example
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# Add your provider configurations here
|
||||
```
|
||||
|
||||
### Update to Latest Version
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||
|
||||
126
README_CN.md
126
README_CN.md
@@ -8,132 +8,6 @@
|
||||
|
||||
该 Plus 版本的主线功能与主线功能强制同步。
|
||||
|
||||
## 与主线版本版本差异
|
||||
|
||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||
|
||||
## 新增功能 (Plus 增强版)
|
||||
|
||||
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
|
||||
- **请求限流器**: 内置请求限流,防止 API 滥用
|
||||
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
|
||||
- **监控指标**: 请求指标收集,用于监控和调试
|
||||
- **设备指纹**: 设备指纹生成,增强安全性
|
||||
- **冷却管理**: 智能冷却机制,应对 API 速率限制
|
||||
- **用量检查器**: 实时用量监控和配额管理
|
||||
- **模型转换器**: 跨供应商的统一模型名称转换
|
||||
- **UTF-8 流处理**: 改进的流式响应处理
|
||||
|
||||
## Kiro 认证
|
||||
|
||||
### 命令行登录
|
||||
|
||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
||||
|
||||
**AWS Builder ID**(推荐):
|
||||
|
||||
```bash
|
||||
# 设备码流程
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# 授权码流程
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**从 Kiro IDE 导入令牌:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
获取令牌步骤:
|
||||
|
||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# 指定区域
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**附加参数:**
|
||||
|
||||
| 参数 | 说明 |
|
||||
|------|------|
|
||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
||||
|
||||
### 网页端 OAuth 登录
|
||||
|
||||
访问 Kiro OAuth 网页认证界面:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
||||
- AWS Builder ID 登录
|
||||
- AWS Identity Center (IDC) 登录
|
||||
- 从 Kiro IDE 导入令牌
|
||||
|
||||
## Docker 快速部署
|
||||
|
||||
### 一键部署
|
||||
|
||||
```bash
|
||||
# 创建部署目录
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# 创建 docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# 下载示例配置
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# 拉取并启动
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
启动前请编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# 基本配置示例
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# 在此添加你的供应商配置
|
||||
```
|
||||
|
||||
### 更新到最新版本
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||
|
||||
@@ -80,6 +80,10 @@ passthrough-headers: false
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Maximum number of different credentials to try for one failed request.
|
||||
# Set to 0 to keep legacy behavior (try all available credentials).
|
||||
max-retry-credentials: 0
|
||||
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
@@ -215,6 +219,17 @@ nonstream-keepalive-interval: 0
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# # You may repeat the same alias to build an internal model pool.
|
||||
# # The client still sees only one alias in the model list.
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
# # and if the chosen upstream fails before producing output, the request will
|
||||
# # continue with the next upstream model in the same alias pool.
|
||||
# - name: "qwen3.5-plus"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "glm-5"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "claude-opus-4.66"
|
||||
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
@@ -229,6 +244,9 @@ nonstream-keepalive-interval: 0
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
# excluded-models: # optional: models to exclude from listing
|
||||
# - "imagen-3.0-generate-002"
|
||||
# - "imagen-*"
|
||||
|
||||
# Amp Integration
|
||||
# ampcode:
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -48,14 +49,11 @@ import (
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
anthropicCallbackPort = 54545
|
||||
geminiCallbackPort = 8085
|
||||
codexCallbackPort = 1455
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
geminiCLIApiClient = "gl-node/22.17.0"
|
||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
anthropicCallbackPort = 54545
|
||||
geminiCallbackPort = 8085
|
||||
codexCallbackPort = 1455
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
)
|
||||
|
||||
type callbackForwarder struct {
|
||||
@@ -195,17 +193,6 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func stopCallbackForwarder(port int) {
|
||||
callbackForwardersMu.Lock()
|
||||
forwarder := callbackForwarders[port]
|
||||
if forwarder != nil {
|
||||
delete(callbackForwarders, port)
|
||||
}
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil {
|
||||
return
|
||||
@@ -647,44 +634,85 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
full := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
if !filepath.IsAbs(full) {
|
||||
if abs, errAbs := filepath.Abs(full); errAbs == nil {
|
||||
full = abs
|
||||
|
||||
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
targetID := ""
|
||||
if targetAuth := h.findAuthForDelete(name); targetAuth != nil {
|
||||
targetID = strings.TrimSpace(targetAuth.ID)
|
||||
if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" {
|
||||
targetPath = path
|
||||
}
|
||||
}
|
||||
if err := os.Remove(full); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if !filepath.IsAbs(targetPath) {
|
||||
if abs, errAbs := filepath.Abs(targetPath); errAbs == nil {
|
||||
targetPath = abs
|
||||
}
|
||||
}
|
||||
if errRemove := os.Remove(targetPath); errRemove != nil {
|
||||
if os.IsNotExist(errRemove) {
|
||||
c.JSON(404, gin.H{"error": "file not found"})
|
||||
} else {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)})
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.deleteTokenRecord(ctx, full); err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
|
||||
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
|
||||
return
|
||||
}
|
||||
h.disableAuth(ctx, full)
|
||||
if targetID != "" {
|
||||
h.disableAuth(ctx, targetID)
|
||||
} else {
|
||||
h.disableAuth(ctx, targetPath)
|
||||
}
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
|
||||
if h == nil || h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
return auth
|
||||
}
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(auth.FileName) == name {
|
||||
return auth
|
||||
}
|
||||
if filepath.Base(strings.TrimSpace(authAttribute(auth, "path"))) == name {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) authIDForPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return path
|
||||
id := path
|
||||
if h != nil && h.cfg != nil {
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if authDir != "" {
|
||||
if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
}
|
||||
}
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if authDir == "" {
|
||||
return path
|
||||
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
|
||||
if runtime.GOOS == "windows" {
|
||||
id = strings.ToLower(id)
|
||||
}
|
||||
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
|
||||
return rel
|
||||
}
|
||||
return path
|
||||
return id
|
||||
}
|
||||
|
||||
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
||||
@@ -902,10 +930,19 @@ func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
}
|
||||
authID := h.authIDForPath(id)
|
||||
if authID == "" {
|
||||
authID = strings.TrimSpace(id)
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(id); ok {
|
||||
auth.Disabled = true
|
||||
auth.Status = coreauth.StatusDisabled
|
||||
auth.StatusMessage = "removed via management API"
|
||||
auth.UpdatedAt = time.Now()
|
||||
_, _ = h.authManager.Update(ctx, auth)
|
||||
return
|
||||
}
|
||||
authID := h.authIDForPath(id)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
@@ -1275,12 +1312,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1289,7 +1326,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
@@ -1300,19 +1337,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1325,13 +1362,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2384,9 +2421,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
||||
return fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
@@ -2456,7 +2491,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -2477,7 +2512,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo = httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -2554,6 +2589,7 @@ func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||
}
|
||||
return coreauth.WithRequestInfo(ctx, info)
|
||||
}
|
||||
|
||||
const kiroCallbackPort = 9876
|
||||
|
||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
@@ -2690,6 +2726,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||
if errTarget != nil {
|
||||
@@ -2697,7 +2734,8 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -2706,7 +2744,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(kiroCallbackPort)
|
||||
defer stopCallbackForwarderInstance(kiroCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||
@@ -2909,7 +2947,7 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
Metadata: map[string]any{
|
||||
"email": status.UserEmail,
|
||||
"organization_id": orgID,
|
||||
"model": defaults.Model,
|
||||
"model": defaults.Model,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
authDir := filepath.Join(tempDir, "auth")
|
||||
externalDir := filepath.Join(tempDir, "external")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
if errMkdirExternal := os.MkdirAll(externalDir, 0o700); errMkdirExternal != nil {
|
||||
t.Fatalf("failed to create external dir: %v", errMkdirExternal)
|
||||
}
|
||||
|
||||
fileName := "codex-user@example.com-plus.json"
|
||||
shadowPath := filepath.Join(authDir, fileName)
|
||||
realPath := filepath.Join(externalDir, fileName)
|
||||
if errWriteShadow := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); errWriteShadow != nil {
|
||||
t.Fatalf("failed to write shadow file: %v", errWriteShadow)
|
||||
}
|
||||
if errWriteReal := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); errWriteReal != nil {
|
||||
t.Fatalf("failed to write real file: %v", errWriteReal)
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "legacy/" + fileName,
|
||||
FileName: fileName,
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusError,
|
||||
Unavailable: true,
|
||||
Attributes: map[string]string{
|
||||
"path": realPath,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "codex",
|
||||
"email": "real@example.com",
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||
deleteCtx.Request = deleteReq
|
||||
h.DeleteAuthFile(deleteCtx)
|
||||
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
if _, errStatReal := os.Stat(realPath); !os.IsNotExist(errStatReal) {
|
||||
t.Fatalf("expected managed auth file to be removed, stat err: %v", errStatReal)
|
||||
}
|
||||
if _, errStatShadow := os.Stat(shadowPath); errStatShadow != nil {
|
||||
t.Fatalf("expected shadow auth file to remain, stat err: %v", errStatShadow)
|
||||
}
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listCtx, _ := gin.CreateTestContext(listRec)
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||
listCtx.Request = listReq
|
||||
h.ListAuthFiles(listCtx)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String())
|
||||
}
|
||||
var listPayload map[string]any
|
||||
if errUnmarshal := json.Unmarshal(listRec.Body.Bytes(), &listPayload); errUnmarshal != nil {
|
||||
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||
}
|
||||
filesRaw, ok := listPayload["files"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected files array, payload: %#v", listPayload)
|
||||
}
|
||||
if len(filesRaw) != 0 {
|
||||
t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(filesRaw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
fileName := "fallback-user.json"
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex"}`), 0o600); errWrite != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||
deleteCtx.Request = deleteReq
|
||||
h.DeleteAuthFile(deleteCtx)
|
||||
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) {
|
||||
t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", errStat)
|
||||
}
|
||||
}
|
||||
@@ -516,12 +516,13 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
||||
}
|
||||
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||
type vertexCompatPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
Models *[]config.VertexCompatModel `json:"models"`
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
Models *[]config.VertexCompatModel `json:"models"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
@@ -585,6 +586,9 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeVertexCompatKey(&entry)
|
||||
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
@@ -1029,6 +1033,7 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -77,6 +78,9 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
req.Header.Del("X-Api-Key")
|
||||
req.Header.Del("X-Goog-Api-Key")
|
||||
|
||||
// Remove proxy, client identity, and browser fingerprint headers
|
||||
misc.ScrubProxyAndFingerprintHeaders(req)
|
||||
|
||||
// Remove query-based credentials if they match the authenticated client API key.
|
||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||
// breaking unrelated upstream query parameters.
|
||||
|
||||
@@ -60,10 +60,8 @@ type ServerOption func(*serverOptionConfig)
|
||||
|
||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||
configDir := filepath.Dir(configPath)
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
|
||||
logsDir := logging.ResolveLogDirectory(cfg)
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, logsDir, configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
@@ -260,7 +258,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
if authManager != nil {
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
@@ -946,7 +944,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
|
||||
// Update log level dynamically when debug flag changes
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
@@ -109,3 +111,100 @@ func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
t.Setenv("WRITABLE_PATH", "")
|
||||
t.Setenv("writable_path", "")
|
||||
|
||||
originalWD, errGetwd := os.Getwd()
|
||||
if errGetwd != nil {
|
||||
t.Fatalf("failed to get current working directory: %v", errGetwd)
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
if errChdir := os.Chdir(tmpDir); errChdir != nil {
|
||||
t.Fatalf("failed to switch working directory: %v", errChdir)
|
||||
}
|
||||
defer func() {
|
||||
if errChdirBack := os.Chdir(originalWD); errChdirBack != nil {
|
||||
t.Fatalf("failed to restore working directory: %v", errChdirBack)
|
||||
}
|
||||
}()
|
||||
|
||||
// Force ResolveLogDirectory to fallback to auth-dir/logs by making ./logs not a writable directory.
|
||||
if errWriteFile := os.WriteFile(filepath.Join(tmpDir, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
|
||||
t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(tmpDir, "config")
|
||||
if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
|
||||
t.Fatalf("failed to create config dir: %v", errMkdirConfig)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "config.yaml")
|
||||
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
|
||||
cfg := &proxyconfig.Config{
|
||||
SDKConfig: proxyconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
},
|
||||
AuthDir: authDir,
|
||||
ErrorLogsMaxFiles: 10,
|
||||
}
|
||||
|
||||
logger := defaultRequestLoggerFactory(cfg, configPath)
|
||||
fileLogger, ok := logger.(*internallogging.FileRequestLogger)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FileRequestLogger, got %T", logger)
|
||||
}
|
||||
|
||||
errLog := fileLogger.LogRequestWithOptions(
|
||||
"/v1/chat/completions",
|
||||
http.MethodPost,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"input":"hello"}`),
|
||||
http.StatusBadGateway,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"error":"upstream failure"}`),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
"issue-1711",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
)
|
||||
if errLog != nil {
|
||||
t.Fatalf("failed to write forced error request log: %v", errLog)
|
||||
}
|
||||
|
||||
authLogsDir := filepath.Join(authDir, "logs")
|
||||
authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
|
||||
if errReadAuthDir != nil {
|
||||
t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
|
||||
}
|
||||
foundErrorLogInAuthDir := false
|
||||
for _, entry := range authEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
foundErrorLogInAuthDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundErrorLogInAuthDir {
|
||||
t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
|
||||
}
|
||||
|
||||
configLogsDir := filepath.Join(configDir, "logs")
|
||||
configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
|
||||
if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
|
||||
t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
|
||||
}
|
||||
for _, entry := range configEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
type utlsRoundTripper struct {
|
||||
// mu protects the connections map and pending map
|
||||
@@ -100,7 +100,9 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
|
||||
// createConnection creates a new HTTP/2 connection with Chrome TLS fingerprint.
|
||||
// Chrome's TLS fingerprint is closer to Node.js/OpenSSL (which real Claude Code uses)
|
||||
// than Firefox, reducing the mismatch between TLS layer and HTTP headers.
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := t.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
@@ -108,7 +110,7 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
@@ -156,7 +158,7 @@ func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
||||
// for Anthropic domains by using utls with Firefox fingerprint.
|
||||
// for Anthropic domains by using utls with Chrome fingerprint.
|
||||
// It accepts optional SDK configuration for proxy settings.
|
||||
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
||||
return &http.Client{
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -222,6 +224,97 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
|
||||
type CopilotModelEntry struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||
type CopilotModelsResponse struct {
|
||||
Data []CopilotModelEntry `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
|
||||
const maxModelsResponseSize = 2 * 1024 * 1024
|
||||
|
||||
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||
var allowedCopilotAPIHosts = map[string]bool{
|
||||
"api.githubcopilot.com": true,
|
||||
"api.individual.githubcopilot.com": true,
|
||||
"api.business.githubcopilot.com": true,
|
||||
"copilot-proxy.githubusercontent.com": true,
|
||||
}
|
||||
|
||||
// ListModels fetches the list of available models from the Copilot API.
|
||||
// It requires a valid Copilot API token (not the GitHub access token).
|
||||
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
|
||||
if apiToken == nil || apiToken.Token == "" {
|
||||
return nil, fmt.Errorf("copilot: api token is required for listing models")
|
||||
}
|
||||
|
||||
// Build models URL, validating the endpoint host to prevent SSRF.
|
||||
modelsURL := copilotAPIEndpoint + "/models"
|
||||
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
|
||||
parsed, err := url.Parse(ep)
|
||||
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
|
||||
modelsURL = ep + "/models"
|
||||
} else {
|
||||
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: models request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot list models: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// Limit response body to prevent memory exhaustion.
|
||||
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
|
||||
bodyBytes, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
|
||||
}
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
return nil, fmt.Errorf("copilot: list models failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp CopilotModelsResponse
|
||||
if err = json.Unmarshal(bodyBytes, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to parse models response: %w", err)
|
||||
}
|
||||
|
||||
return modelsResp.Data, nil
|
||||
}
|
||||
|
||||
// ListModelsWithGitHubToken is a convenience method that exchanges a GitHub access token
|
||||
// for a Copilot API token and then fetches the available models.
|
||||
func (c *CopilotAuth) ListModelsWithGitHubToken(ctx context.Context, githubAccessToken string) ([]CopilotModelEntry, error) {
|
||||
apiToken, err := c.GetCopilotAPIToken(ctx, githubAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to get API token for model listing: %w", err)
|
||||
}
|
||||
|
||||
return c.ListModels(ctx, apiToken)
|
||||
}
|
||||
|
||||
// buildChatCompletionURL builds the URL for chat completions API.
|
||||
func buildChatCompletionURL() string {
|
||||
return copilotAPIEndpoint + "/chat/completions"
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -27,11 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
geminiCLIApiClient = "gl-node/22.17.0"
|
||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
)
|
||||
|
||||
type projectSelectionRequiredError struct{}
|
||||
@@ -409,9 +407,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
||||
return fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
@@ -630,7 +626,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -651,7 +647,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo = httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
|
||||
@@ -69,6 +69,9 @@ type Config struct {
|
||||
|
||||
// RequestRetry defines the retry times when the request failed.
|
||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||
// Set to 0 or a negative value to keep trying all available credentials (legacy behavior).
|
||||
MaxRetryCredentials int `yaml:"max-retry-credentials" json:"max-retry-credentials"`
|
||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
||||
|
||||
@@ -576,16 +579,6 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// // Log warning but don't fail - config loading should still work
|
||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
// } else if migrated {
|
||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
// }
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -673,6 +666,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
if cfg.MaxRetryCredentials < 0 {
|
||||
cfg.MaxRetryCredentials = 0
|
||||
}
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
@@ -1669,9 +1666,6 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
|
||||
srcIdx := findMapKeyIndex(srcRoot, key)
|
||||
if srcIdx < 0 {
|
||||
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
||||
//
|
||||
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
|
||||
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
|
||||
// When users delete the last channel from oauth-model-alias via the management API,
|
||||
// we want that deletion to persist across hot reloads and restarts.
|
||||
if key == "oauth-model-alias" {
|
||||
|
||||
61
internal/config/oauth_model_alias_defaults.go
Normal file
61
internal/config/oauth_model_alias_defaults.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// defaultKiroAliases returns default oauth-model-alias entries for Kiro.
|
||||
// These aliases expose standard Claude IDs for Kiro-prefixed upstream models.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.6
|
||||
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultGitHubCopilotAliases returns default oauth-model-alias entries for
|
||||
// GitHub Copilot Claude models. It exposes hyphen-style IDs used by clients.
|
||||
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// GitHubCopilotAliasesFromModels generates oauth-model-alias entries from a dynamic
|
||||
// list of model IDs fetched from the Copilot API. It auto-creates aliases for
|
||||
// models whose ID contains a dot (e.g. "claude-opus-4.6" → "claude-opus-4-6"),
|
||||
// which is the pattern used by Claude models on Copilot.
|
||||
func GitHubCopilotAliasesFromModels(modelIDs []string) []OAuthModelAlias {
|
||||
var aliases []OAuthModelAlias
|
||||
seen := make(map[string]struct{})
|
||||
for _, id := range modelIDs {
|
||||
if !strings.Contains(id, ".") {
|
||||
continue
|
||||
}
|
||||
hyphenID := strings.ReplaceAll(id, ".", "-")
|
||||
key := id + "→" + hyphenID
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
aliases = append(aliases, OAuthModelAlias{Name: id, Alias: hyphenID, Fork: true})
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
@@ -1,316 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// antigravityModelConversionTable maps old built-in aliases to actual model names
|
||||
// for the antigravity channel during migration.
|
||||
var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
||||
// names so that clients like Claude Code can use standard names directly.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.6
|
||||
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
|
||||
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
|
||||
// This keeps compatibility with clients (e.g. Claude Code) that use
|
||||
// Anthropic-style model IDs like "claude-opus-4-6".
|
||||
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
// for the antigravity channel when neither field exists.
|
||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
|
||||
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
|
||||
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
|
||||
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
|
||||
// to oauth-model-alias at startup. Returns true if migration was performed.
|
||||
//
|
||||
// Migration flow:
|
||||
// 1. Check if oauth-model-alias exists -> skip migration
|
||||
// 2. Check if oauth-model-mappings exists -> convert and migrate
|
||||
// - For antigravity channel, convert old built-in aliases to actual model names
|
||||
//
|
||||
// 3. Neither exists -> add default antigravity config
|
||||
func MigrateOAuthModelAlias(configFile string) (bool, error) {
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Parse YAML into node tree to preserve structure
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
rootMap := root.Content[0]
|
||||
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if oauth-model-alias already exists
|
||||
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if oauth-model-mappings exists
|
||||
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
|
||||
if oldIdx >= 0 {
|
||||
// Migrate from old field
|
||||
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
|
||||
}
|
||||
|
||||
// Neither field exists - add default antigravity config
|
||||
return addDefaultAntigravityConfig(configFile, &root, rootMap)
|
||||
}
|
||||
|
||||
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
|
||||
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
|
||||
if oldIdx+1 >= len(rootMap.Content) {
|
||||
return false, nil
|
||||
}
|
||||
oldValue := rootMap.Content[oldIdx+1]
|
||||
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Parse the old aliases
|
||||
oldAliases := parseOldAliasNode(oldValue)
|
||||
if len(oldAliases) == 0 {
|
||||
// Remove the old field and write
|
||||
removeMapKeyByIndex(rootMap, oldIdx)
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// Convert model names for antigravity channel
|
||||
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
|
||||
for channel, entries := range oldAliases {
|
||||
converted := make([]OAuthModelAlias, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
newEntry := OAuthModelAlias{
|
||||
Name: entry.Name,
|
||||
Alias: entry.Alias,
|
||||
Fork: entry.Fork,
|
||||
}
|
||||
// Convert model names for antigravity channel
|
||||
if strings.EqualFold(channel, "antigravity") {
|
||||
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
|
||||
newEntry.Name = actual
|
||||
}
|
||||
}
|
||||
converted = append(converted, newEntry)
|
||||
}
|
||||
newAliases[channel] = converted
|
||||
}
|
||||
|
||||
// For antigravity channel, supplement missing default aliases
|
||||
if antigravityEntries, exists := newAliases["antigravity"]; exists {
|
||||
// Build a set of already configured model names (upstream names)
|
||||
configuredModels := make(map[string]bool, len(antigravityEntries))
|
||||
for _, entry := range antigravityEntries {
|
||||
configuredModels[entry.Name] = true
|
||||
}
|
||||
|
||||
// Add missing default aliases
|
||||
for _, defaultAlias := range defaultAntigravityAliases() {
|
||||
if !configuredModels[defaultAlias.Name] {
|
||||
antigravityEntries = append(antigravityEntries, defaultAlias)
|
||||
}
|
||||
}
|
||||
newAliases["antigravity"] = antigravityEntries
|
||||
}
|
||||
|
||||
// Build new node
|
||||
newNode := buildOAuthModelAliasNode(newAliases)
|
||||
|
||||
// Replace old key with new key and value
|
||||
rootMap.Content[oldIdx].Value = "oauth-model-alias"
|
||||
rootMap.Content[oldIdx+1] = newNode
|
||||
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// addDefaultAntigravityConfig adds the default antigravity configuration
|
||||
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
|
||||
defaults := map[string][]OAuthModelAlias{
|
||||
"antigravity": defaultAntigravityAliases(),
|
||||
}
|
||||
newNode := buildOAuthModelAliasNode(defaults)
|
||||
|
||||
// Add new key-value pair
|
||||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
|
||||
rootMap.Content = append(rootMap.Content, keyNode, newNode)
|
||||
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// parseOldAliasNode parses the old oauth-model-mappings node structure
|
||||
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
|
||||
if node == nil || node.Kind != yaml.MappingNode {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string][]OAuthModelAlias)
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
channelNode := node.Content[i]
|
||||
entriesNode := node.Content[i+1]
|
||||
if channelNode == nil || entriesNode == nil {
|
||||
continue
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
|
||||
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
|
||||
continue
|
||||
}
|
||||
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
|
||||
for _, entryNode := range entriesNode.Content {
|
||||
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
|
||||
continue
|
||||
}
|
||||
entry := parseAliasEntry(entryNode)
|
||||
if entry.Name != "" && entry.Alias != "" {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
}
|
||||
if len(entries) > 0 {
|
||||
result[channel] = entries
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseAliasEntry parses a single alias entry node
|
||||
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
|
||||
var entry OAuthModelAlias
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valNode := node.Content[i+1]
|
||||
if keyNode == nil || valNode == nil {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
|
||||
case "name":
|
||||
entry.Name = strings.TrimSpace(valNode.Value)
|
||||
case "alias":
|
||||
entry.Alias = strings.TrimSpace(valNode.Value)
|
||||
case "fork":
|
||||
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
|
||||
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
|
||||
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
for channel, entries := range aliases {
|
||||
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
|
||||
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
|
||||
for _, entry := range entries {
|
||||
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
entryNode.Content = append(entryNode.Content,
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
|
||||
)
|
||||
if entry.Fork {
|
||||
entryNode.Content = append(entryNode.Content,
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
|
||||
)
|
||||
}
|
||||
entriesNode.Content = append(entriesNode.Content, entryNode)
|
||||
}
|
||||
node.Content = append(node.Content, channelNode, entriesNode)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
|
||||
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
|
||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||
return
|
||||
}
|
||||
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
|
||||
return
|
||||
}
|
||||
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
|
||||
}
|
||||
|
||||
// writeYAMLNode writes the YAML node tree back to file
|
||||
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
|
||||
f, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := yaml.NewEncoder(f)
|
||||
enc.SetIndent(2)
|
||||
if err := enc.Encode(root); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -1,245 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `oauth-model-alias:
|
||||
gemini-cli:
|
||||
- name: "gemini-2.5-pro"
|
||||
alias: "g2.5p"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration when oauth-model-alias already exists")
|
||||
}
|
||||
|
||||
// Verify file unchanged
|
||||
data, _ := os.ReadFile(configFile)
|
||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||
t.Fatal("file should still contain oauth-model-alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `oauth-model-mappings:
|
||||
gemini-cli:
|
||||
- name: "gemini-2.5-pro"
|
||||
alias: "g2.5p"
|
||||
fork: true
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify new field exists and old field removed
|
||||
data, _ := os.ReadFile(configFile)
|
||||
if strings.Contains(string(data), "oauth-model-mappings:") {
|
||||
t.Fatal("old field should be removed")
|
||||
}
|
||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||
t.Fatal("new field should exist")
|
||||
}
|
||||
|
||||
// Parse and verify structure
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Use old model names that should be converted
|
||||
content := `oauth-model-mappings:
|
||||
antigravity:
|
||||
- name: "gemini-2.5-computer-use-preview-10-2025"
|
||||
alias: "computer-use"
|
||||
- name: "gemini-3-pro-preview"
|
||||
alias: "g3p"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify model names were converted
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
|
||||
}
|
||||
if !strings.Contains(content, "gemini-3-pro-high") {
|
||||
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
|
||||
}
|
||||
|
||||
// Verify missing default aliases were supplemented
|
||||
if !strings.Contains(content, "gemini-3-pro-image") {
|
||||
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
|
||||
}
|
||||
if !strings.Contains(content, "gemini-3-flash") {
|
||||
t.Fatal("expected missing default alias gemini-3-flash to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-sonnet-4-5") {
|
||||
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `debug: true
|
||||
port: 8080
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to add default config")
|
||||
}
|
||||
|
||||
// Verify default antigravity config was added
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "oauth-model-alias:") {
|
||||
t.Fatal("expected oauth-model-alias to be added")
|
||||
}
|
||||
if !strings.Contains(content, "antigravity:") {
|
||||
t.Fatal("expected antigravity channel to be added")
|
||||
}
|
||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `debug: true
|
||||
port: 8080
|
||||
oauth-model-mappings:
|
||||
gemini-cli:
|
||||
- name: "test"
|
||||
alias: "t"
|
||||
api-keys:
|
||||
- "key1"
|
||||
- "key2"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify other config preserved
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "debug: true") {
|
||||
t.Fatal("expected debug field to be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "port: 8080") {
|
||||
t.Fatal("expected port field to be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "api-keys:") {
|
||||
t.Fatal("expected api-keys field to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for nonexistent file: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration for empty file")
|
||||
}
|
||||
}
|
||||
@@ -34,6 +34,9 @@ type VertexCompatKey struct {
|
||||
|
||||
// Models defines the model configurations including aliases for routing.
|
||||
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
|
||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
||||
@@ -74,6 +77,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
}
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
|
||||
// Sanitize models: remove entries without valid alias
|
||||
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
||||
|
||||
@@ -1 +1 @@
|
||||
[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude.","cache_control":{"type":"ephemeral"}}]
|
||||
[{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}]
|
||||
@@ -4,10 +4,98 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
|
||||
GeminiCLIVersion = "0.31.0"
|
||||
|
||||
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
|
||||
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
|
||||
)
|
||||
|
||||
// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI.
|
||||
func geminiCLIOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return "win32"
|
||||
default:
|
||||
return runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI.
|
||||
func geminiCLIArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format.
|
||||
// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable.
|
||||
func GeminiCLIUserAgent(model string) string {
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
|
||||
}
|
||||
|
||||
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
|
||||
// proxy infrastructure, client identity, or browser fingerprints from an
|
||||
// outgoing request. This ensures requests to upstream services look like they
|
||||
// originate directly from a native client rather than a third-party client
|
||||
// behind a reverse proxy.
|
||||
func ScrubProxyAndFingerprintHeaders(req *http.Request) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// --- Proxy tracing headers ---
|
||||
req.Header.Del("X-Forwarded-For")
|
||||
req.Header.Del("X-Forwarded-Host")
|
||||
req.Header.Del("X-Forwarded-Proto")
|
||||
req.Header.Del("X-Forwarded-Port")
|
||||
req.Header.Del("X-Real-IP")
|
||||
req.Header.Del("Forwarded")
|
||||
req.Header.Del("Via")
|
||||
|
||||
// --- Client identity headers ---
|
||||
req.Header.Del("X-Title")
|
||||
req.Header.Del("X-Stainless-Lang")
|
||||
req.Header.Del("X-Stainless-Package-Version")
|
||||
req.Header.Del("X-Stainless-Os")
|
||||
req.Header.Del("X-Stainless-Arch")
|
||||
req.Header.Del("X-Stainless-Runtime")
|
||||
req.Header.Del("X-Stainless-Runtime-Version")
|
||||
req.Header.Del("Http-Referer")
|
||||
req.Header.Del("Referer")
|
||||
|
||||
// --- Browser / Chromium fingerprint headers ---
|
||||
// These are sent by Electron-based clients (e.g. CherryStudio) using the
|
||||
// Fetch API, but NOT by Node.js https module (which Antigravity uses).
|
||||
req.Header.Del("Sec-Ch-Ua")
|
||||
req.Header.Del("Sec-Ch-Ua-Mobile")
|
||||
req.Header.Del("Sec-Ch-Ua-Platform")
|
||||
req.Header.Del("Sec-Fetch-Mode")
|
||||
req.Header.Del("Sec-Fetch-Site")
|
||||
req.Header.Del("Sec-Fetch-Dest")
|
||||
req.Header.Del("Priority")
|
||||
|
||||
// --- Encoding negotiation ---
|
||||
// Antigravity (Node.js) sends "gzip, deflate, br" by default;
|
||||
// Electron-based clients may add "zstd" which is a fingerprint mismatch.
|
||||
req.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
// EnsureHeader ensures that a header exists in the target header map by checking
|
||||
// multiple sources in order of priority: source headers, existing target headers,
|
||||
// and finally the default value. It only sets the header if it's not already present
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
// - kiro
|
||||
// - kilo
|
||||
// - github-copilot
|
||||
// - kiro
|
||||
// - amazonq
|
||||
// - antigravity (returns static overrides only)
|
||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
@@ -152,6 +151,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: entry.Description,
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
@@ -49,7 +49,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 1000000,
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
@@ -211,6 +211,21 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-flash-image-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-flash-image-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Flash Image Preview",
|
||||
Description: "Gemini 3.1 Flash Image Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -220,12 +235,27 @@ func GetGeminiModels() []*ModelInfo {
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-flash-lite-preview",
|
||||
Object: "model",
|
||||
Created: 1776288000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-flash-lite-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Flash Lite Preview",
|
||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
@@ -336,6 +366,32 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-flash-image-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-flash-image-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Flash Image Preview",
|
||||
Description: "Gemini 3.1 Flash Image Preview",
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-flash-lite-preview",
|
||||
Object: "model",
|
||||
Created: 1776288000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-flash-lite-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Flash Lite Preview",
|
||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
@@ -508,6 +564,21 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-flash-lite-preview",
|
||||
Object: "model",
|
||||
Created: 1776288000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-flash-lite-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Flash Lite Preview",
|
||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -604,6 +675,21 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-flash-lite-preview",
|
||||
Object: "model",
|
||||
Created: 1776288000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-flash-lite-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Flash Lite Preview",
|
||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-pro-latest",
|
||||
Object: "model",
|
||||
@@ -839,6 +925,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4",
|
||||
Object: "model",
|
||||
Created: 1772668800,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.4",
|
||||
DisplayName: "GPT 5.4",
|
||||
Description: "Stable version of GPT 5.4",
|
||||
ContextLength: 1_050_000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -959,22 +1059,18 @@ type AntigravityModelConfig struct {
|
||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3.1-flash-lite-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ type ModelInfo struct {
|
||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
|
||||
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
|
||||
// SupportedInputModalities lists supported input modalities (e.g., TEXT, IMAGE, VIDEO, AUDIO)
|
||||
SupportedInputModalities []string `json:"supportedInputModalities,omitempty"`
|
||||
// SupportedOutputModalities lists supported output modalities (e.g., TEXT, IMAGE)
|
||||
SupportedOutputModalities []string `json:"supportedOutputModalities,omitempty"`
|
||||
|
||||
// Thinking holds provider-specific reasoning/thinking budget capabilities.
|
||||
// This is optional and currently used for Gemini thinking budget normalization.
|
||||
@@ -60,6 +64,11 @@ type ModelInfo struct {
|
||||
UserDefined bool `json:"-"`
|
||||
}
|
||||
|
||||
type availableModelsCacheEntry struct {
|
||||
models []map[string]any
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||
// Values are interpreted in provider-native token units.
|
||||
type ThinkingSupport struct {
|
||||
@@ -114,6 +123,8 @@ type ModelRegistry struct {
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
mutex *sync.RWMutex
|
||||
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
|
||||
availableModelsCache map[string]availableModelsCacheEntry
|
||||
// hook is an optional callback sink for model registration changes
|
||||
hook ModelRegistryHook
|
||||
}
|
||||
@@ -126,15 +137,28 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
availableModelsCache: make(map[string]availableModelsCacheEntry),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
|
||||
if r.availableModelsCache == nil {
|
||||
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
|
||||
if len(r.availableModelsCache) == 0 {
|
||||
return
|
||||
}
|
||||
clear(r.availableModelsCache)
|
||||
}
|
||||
|
||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
@@ -149,9 +173,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
}
|
||||
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||
return info
|
||||
return cloneModelInfo(info)
|
||||
}
|
||||
return LookupStaticModelInfo(modelID)
|
||||
return cloneModelInfo(LookupStaticModelInfo(modelID))
|
||||
}
|
||||
|
||||
// SetHook sets an optional hook for observing model registration changes.
|
||||
@@ -209,6 +233,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
||||
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
provider := strings.ToLower(clientProvider)
|
||||
uniqueModelIDs := make([]string, 0, len(models))
|
||||
@@ -234,6 +259,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
}
|
||||
@@ -261,6 +287,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
} else {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||
misc.LogCredentialSeparator()
|
||||
@@ -404,6 +431,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||
@@ -501,8 +529,18 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if len(model.SupportedEndpoints) > 0 {
|
||||
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
|
||||
if len(model.SupportedInputModalities) > 0 {
|
||||
copyModel.SupportedInputModalities = append([]string(nil), model.SupportedInputModalities...)
|
||||
}
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
copyThinking := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
copyModel.Thinking = ©Thinking
|
||||
}
|
||||
return ©Model
|
||||
}
|
||||
@@ -533,6 +571,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.unregisterClientInternal(clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
}
|
||||
|
||||
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
||||
@@ -599,9 +638,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -613,9 +655,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -631,6 +675,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil {
|
||||
@@ -644,6 +689,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
}
|
||||
registration.SuspendedClients[clientID] = reason
|
||||
registration.LastUpdated = time.Now()
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
if reason != "" {
|
||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||
} else {
|
||||
@@ -661,6 +707,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||
@@ -671,6 +718,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
}
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
registration.LastUpdated = time.Now()
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||
}
|
||||
|
||||
@@ -706,22 +754,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
|
||||
// Returns:
|
||||
// - []map[string]any: List of available models in the requested format
|
||||
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
now := time.Now()
|
||||
|
||||
models := make([]map[string]any, 0)
|
||||
r.mutex.RLock()
|
||||
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||
models := cloneModelMaps(cache.models)
|
||||
r.mutex.RUnlock()
|
||||
return models
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||
return cloneModelMaps(cache.models)
|
||||
}
|
||||
|
||||
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
|
||||
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
|
||||
models: cloneModelMaps(models),
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||
models := make([]map[string]any, 0, len(r.models))
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
var expiresAt time.Time
|
||||
|
||||
for _, registration := range r.models {
|
||||
// Check if model has any non-quota-exceeded clients
|
||||
availableClients := registration.Count
|
||||
now := time.Now()
|
||||
|
||||
// Count clients that have exceeded quota but haven't recovered yet
|
||||
expiredClients := 0
|
||||
for _, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime == nil {
|
||||
continue
|
||||
}
|
||||
recoveryAt := quotaTime.Add(quotaExpiredDuration)
|
||||
if now.Before(recoveryAt) {
|
||||
expiredClients++
|
||||
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||
expiresAt = recoveryAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -742,7 +820,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
// Include models that have available clients, or those solely cooling down.
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
model := r.convertModelToMap(registration.Info, handlerType)
|
||||
if model != nil {
|
||||
@@ -751,7 +828,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
return models, expiresAt
|
||||
}
|
||||
|
||||
func cloneModelMaps(models []map[string]any) []map[string]any {
|
||||
cloned := make([]map[string]any, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
cloned = append(cloned, nil)
|
||||
continue
|
||||
}
|
||||
copyModel := make(map[string]any, len(model))
|
||||
for key, value := range model {
|
||||
copyModel[key] = cloneModelMapValue(value)
|
||||
}
|
||||
cloned = append(cloned, copyModel)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneModelMapValue(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
copyMap := make(map[string]any, len(typed))
|
||||
for key, entry := range typed {
|
||||
copyMap[key] = cloneModelMapValue(entry)
|
||||
}
|
||||
return copyMap
|
||||
case []any:
|
||||
copySlice := make([]any, len(typed))
|
||||
for i, entry := range typed {
|
||||
copySlice[i] = cloneModelMapValue(entry)
|
||||
}
|
||||
return copySlice
|
||||
case []string:
|
||||
return append([]string(nil), typed...)
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||
@@ -867,11 +981,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
if entry.info != nil {
|
||||
result = append(result, entry.info)
|
||||
result = append(result, cloneModelInfo(entry.info))
|
||||
continue
|
||||
}
|
||||
if ok && registration != nil && registration.Info != nil {
|
||||
result = append(result, registration.Info)
|
||||
result = append(result, cloneModelInfo(registration.Info))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -980,13 +1094,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||
if reg.Providers != nil {
|
||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||
return info
|
||||
return cloneModelInfo(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to global info (last registered)
|
||||
return reg.Info
|
||||
return cloneModelInfo(reg.Info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1026,7 +1140,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
result["supported_parameters"] = model.SupportedParameters
|
||||
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if len(model.SupportedEndpoints) > 0 {
|
||||
result["supported_endpoints"] = model.SupportedEndpoints
|
||||
@@ -1087,7 +1201,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
||||
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedInputModalities) > 0 {
|
||||
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
|
||||
}
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -1117,15 +1237,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
invalidated := false
|
||||
|
||||
for modelID, registration := range r.models {
|
||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
invalidated = true
|
||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
if invalidated {
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||
@@ -1139,8 +1264,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
// - string: The model ID of the first available model, or empty string if none available
|
||||
// - error: An error if no models are available
|
||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Get all available models for this handler type
|
||||
models := r.GetAvailableModels(handlerType)
|
||||
@@ -1200,13 +1323,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
result = append(result, cloneModelInfo(info))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
result = append(result, cloneModelInfo(reg.Info))
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
54
internal/registry/model_registry_cache_test.go
Normal file
54
internal/registry/model_registry_cache_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package registry
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||
|
||||
first := r.GetAvailableModels("openai")
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(first))
|
||||
}
|
||||
first[0]["id"] = "mutated"
|
||||
first[0]["display_name"] = "Mutated"
|
||||
|
||||
second := r.GetAvailableModels("openai")
|
||||
if got := second[0]["id"]; got != "m1" {
|
||||
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
|
||||
}
|
||||
if got := second[0]["display_name"]; got != "Model One" {
|
||||
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||
|
||||
models := r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(models))
|
||||
}
|
||||
if got := models[0]["display_name"]; got != "Model One" {
|
||||
t.Fatalf("expected initial display_name Model One, got %v", got)
|
||||
}
|
||||
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
|
||||
models = r.GetAvailableModels("openai")
|
||||
if got := models[0]["display_name"]; got != "Model One Updated" {
|
||||
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
|
||||
}
|
||||
|
||||
r.SuspendClientModel("client-1", "m1", "manual")
|
||||
models = r.GetAvailableModels("openai")
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("expected no available models after suspension, got %d", len(models))
|
||||
}
|
||||
|
||||
r.ResumeClientModel("client-1", "m1")
|
||||
models = r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected model to reappear after resume, got %d", len(models))
|
||||
}
|
||||
}
|
||||
149
internal/registry/model_registry_safety_test.go
Normal file
149
internal/registry/model_registry_safety_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetModelInfoReturnsClone(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetModelInfo("m1", "gemini")
|
||||
if first == nil {
|
||||
t.Fatal("expected model info")
|
||||
}
|
||||
first.DisplayName = "mutated"
|
||||
first.Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetModelInfo("m1", "gemini")
|
||||
if second.DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
|
||||
}
|
||||
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelsForClientReturnsClones(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetModelsForClient("client-1")
|
||||
if len(first) != 1 || first[0] == nil {
|
||||
t.Fatalf("expected one model, got %+v", first)
|
||||
}
|
||||
first[0].DisplayName = "mutated"
|
||||
first[0].Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetModelsForClient("client-1")
|
||||
if len(second) != 1 || second[0] == nil {
|
||||
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||
}
|
||||
if second[0].DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||
}
|
||||
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetAvailableModelsByProvider("gemini")
|
||||
if len(first) != 1 || first[0] == nil {
|
||||
t.Fatalf("expected one model, got %+v", first)
|
||||
}
|
||||
first[0].DisplayName = "mutated"
|
||||
first[0].Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetAvailableModelsByProvider("gemini")
|
||||
if len(second) != 1 || second[0] == nil {
|
||||
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||
}
|
||||
if second[0].DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||
}
|
||||
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
|
||||
r.SetModelQuotaExceeded("client-1", "m1")
|
||||
if models := r.GetAvailableModels("openai"); len(models) != 1 {
|
||||
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
|
||||
}
|
||||
|
||||
r.mutex.Lock()
|
||||
quotaTime := time.Now().Add(-6 * time.Minute)
|
||||
r.models["m1"].QuotaExceededClients["client-1"] = "aTime
|
||||
r.mutex.Unlock()
|
||||
|
||||
r.CleanupExpiredQuotas()
|
||||
|
||||
if count := r.GetModelCount("m1"); count != 1 {
|
||||
t.Fatalf("expected model count 1 after cleanup, got %d", count)
|
||||
}
|
||||
models := r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
|
||||
}
|
||||
if got := models[0]["id"]; got != "m1" {
|
||||
t.Fatalf("expected model id m1, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "openai", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
SupportedParameters: []string{"temperature", "top_p"},
|
||||
}})
|
||||
|
||||
first := r.GetAvailableModels("openai")
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("expected one model, got %d", len(first))
|
||||
}
|
||||
params, ok := first[0]["supported_parameters"].([]string)
|
||||
if !ok || len(params) != 2 {
|
||||
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
|
||||
}
|
||||
params[0] = "mutated"
|
||||
|
||||
second := r.GetAvailableModels("openai")
|
||||
params, ok = second[0]["supported_parameters"].([]string)
|
||||
if !ok || len(params) != 2 || params[0] != "temperature" {
|
||||
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
|
||||
first := LookupModelInfo("glm-4.6")
|
||||
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
|
||||
t.Fatalf("expected static model with thinking levels, got %+v", first)
|
||||
}
|
||||
first.Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := LookupModelInfo("glm-4.6")
|
||||
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
|
||||
t.Fatalf("expected static lookup clone, got %+v", second)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -45,10 +46,10 @@ const (
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -142,6 +143,62 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
|
||||
return &AntigravityExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// antigravityTransport is a singleton HTTP/1.1 transport shared by all Antigravity requests.
|
||||
// It is initialized once via antigravityTransportOnce to avoid leaking a new connection pool
|
||||
// (and the goroutines managing it) on every request.
|
||||
var (
|
||||
antigravityTransport *http.Transport
|
||||
antigravityTransportOnce sync.Once
|
||||
)
|
||||
|
||||
func cloneTransportWithHTTP11(base *http.Transport) *http.Transport {
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := base.Clone()
|
||||
clone.ForceAttemptHTTP2 = false
|
||||
// Wipe TLSNextProto to prevent implicit HTTP/2 upgrade.
|
||||
clone.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
|
||||
if clone.TLSClientConfig == nil {
|
||||
clone.TLSClientConfig = &tls.Config{}
|
||||
} else {
|
||||
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
|
||||
}
|
||||
// Actively advertise only HTTP/1.1 in the ALPN handshake.
|
||||
clone.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||
return clone
|
||||
}
|
||||
|
||||
// initAntigravityTransport creates the shared HTTP/1.1 transport exactly once.
|
||||
func initAntigravityTransport() {
|
||||
base, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
base = &http.Transport{}
|
||||
}
|
||||
antigravityTransport = cloneTransportWithHTTP11(base)
|
||||
}
|
||||
|
||||
// newAntigravityHTTPClient creates an HTTP client specifically for Antigravity,
|
||||
// enforcing HTTP/1.1 by disabling HTTP/2 to perfectly mimic Node.js https defaults.
|
||||
// The underlying Transport is a singleton to avoid leaking connection pools.
|
||||
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
antigravityTransportOnce.Do(initAntigravityTransport)
|
||||
|
||||
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
// If no transport is set, use the shared HTTP/1.1 transport.
|
||||
if client.Transport == nil {
|
||||
client.Transport = antigravityTransport
|
||||
return client
|
||||
}
|
||||
|
||||
// Preserve proxy settings from proxy-aware transports while forcing HTTP/1.1.
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
client.Transport = cloneTransportWithHTTP11(transport)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
||||
|
||||
@@ -162,6 +219,8 @@ func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyau
|
||||
}
|
||||
|
||||
// HttpRequest injects Antigravity credentials into the request and executes it.
|
||||
// It uses a whitelist approach: all incoming headers are stripped and only
|
||||
// the minimum set required by the Antigravity protocol is explicitly set.
|
||||
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("antigravity executor: request is nil")
|
||||
@@ -170,10 +229,29 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
|
||||
// --- Whitelist: save only the headers we need from the original request ---
|
||||
contentType := httpReq.Header.Get("Content-Type")
|
||||
|
||||
// Wipe ALL incoming headers
|
||||
for k := range httpReq.Header {
|
||||
delete(httpReq.Header, k)
|
||||
}
|
||||
|
||||
// --- Set only the headers Antigravity actually sends ---
|
||||
if contentType != "" {
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
// Content-Length is managed automatically by Go's http.Client from the Body
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Close = true // sends Connection: close
|
||||
|
||||
// Inject Authorization: Bearer <token>
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -185,7 +263,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
|
||||
|
||||
if isClaude || strings.Contains(baseModel, "gemini-3-pro") {
|
||||
if isClaude || strings.Contains(baseModel, "gemini-3-pro") || strings.Contains(baseModel, "gemini-3.1-flash-image") {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
@@ -220,7 +298,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||
|
||||
@@ -362,7 +440,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||
|
||||
@@ -754,7 +832,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||
|
||||
@@ -956,7 +1034,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -987,10 +1065,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if errReq != nil {
|
||||
return cliproxyexecutor.Response{}, errReq
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
@@ -1084,14 +1162,26 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
|
||||
var payload []byte
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||
}
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
payload = []byte(`{}`)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
@@ -1152,7 +1242,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
continue
|
||||
}
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
@@ -1174,6 +1264,29 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
|
||||
// Build input modalities from upstream capability flags.
|
||||
inputModalities := []string{"TEXT"}
|
||||
if modelData.Get("supportsImages").Bool() {
|
||||
inputModalities = append(inputModalities, "IMAGE")
|
||||
}
|
||||
if modelData.Get("supportsVideo").Bool() {
|
||||
inputModalities = append(inputModalities, "VIDEO")
|
||||
}
|
||||
modelInfo.SupportedInputModalities = inputModalities
|
||||
modelInfo.SupportedOutputModalities = []string{"TEXT"}
|
||||
|
||||
// Token limits from upstream.
|
||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||
modelInfo.InputTokenLimit = int(maxTok)
|
||||
}
|
||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||
modelInfo.OutputTokenLimit = int(maxOut)
|
||||
}
|
||||
|
||||
// Supported generation methods (Gemini v1beta convention).
|
||||
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
|
||||
|
||||
// Look up Thinking support from static config using upstream model name.
|
||||
if modelCfg != nil {
|
||||
if modelCfg.Thinking != nil {
|
||||
@@ -1241,10 +1354,11 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
||||
return auth, errReq
|
||||
}
|
||||
httpReq.Header.Set("Host", "oauth2.googleapis.com")
|
||||
httpReq.Header.Set("User-Agent", defaultAntigravityAgent)
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
// Real Antigravity uses Go's default User-Agent for OAuth token refresh
|
||||
httpReq.Header.Set("User-Agent", "Go-http-client/2.0")
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
return auth, errDo
|
||||
@@ -1315,7 +1429,7 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au
|
||||
return nil
|
||||
}
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
|
||||
if errFetch != nil {
|
||||
return errFetch
|
||||
@@ -1369,7 +1483,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
@@ -1383,18 +1497,18 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
// if useAntigravitySchema {
|
||||
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
// for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
@@ -1406,14 +1520,10 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if stream {
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
@@ -1625,7 +1735,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
template, _ = sjson.Set(template, "requestType", "agent")
|
||||
|
||||
isImageModel := strings.Contains(modelName, "image")
|
||||
|
||||
var reqType string
|
||||
if isImageModel {
|
||||
reqType = "image_gen"
|
||||
} else {
|
||||
reqType = "agent"
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestType", reqType)
|
||||
|
||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||
if projectID != "" {
|
||||
@@ -1633,8 +1752,13 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
if isImageModel {
|
||||
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
}
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||
@@ -1648,6 +1772,10 @@ func generateRequestID() string {
|
||||
return "agent-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func generateImageGenRequestID() string {
|
||||
return fmt.Sprintf("image_gen/%d/%s/12", time.Now().UnixMilli(), uuid.NewString())
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
randSourceMutex.Lock()
|
||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||
|
||||
@@ -59,6 +59,7 @@ func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"deprecated": true,
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
@@ -156,4 +157,7 @@ func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]a
|
||||
if _, ok := mode["enumTitles"]; ok {
|
||||
t.Fatalf("enumTitles should be removed from nested schema")
|
||||
}
|
||||
if _, ok := mode["deprecated"]; ok {
|
||||
t.Fatalf("deprecated should be removed from nested schema")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,14 @@ import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -36,7 +41,9 @@ type ClaudeExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
const claudeToolPrefix = "proxy_"
|
||||
// claudeToolPrefix is empty to match real Claude Code behavior (no tool name prefix).
|
||||
// Previously "proxy_" was used but this is a detectable fingerprint difference.
|
||||
const claudeToolPrefix = ""
|
||||
|
||||
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||
|
||||
@@ -130,6 +137,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
body = ensureCacheControl(body)
|
||||
}
|
||||
|
||||
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
|
||||
// Cloaking and ensureCacheControl may push the total over 4 when the client
|
||||
// (e.g. Amp CLI) already sends multiple cache_control blocks.
|
||||
body = enforceCacheControlLimit(body, 4)
|
||||
|
||||
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
|
||||
// A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages).
|
||||
body = normalizeCacheControlTTL(body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -171,11 +187,27 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return resp, err
|
||||
@@ -271,6 +303,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
body = ensureCacheControl(body)
|
||||
}
|
||||
|
||||
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
|
||||
body = enforceCacheControlLimit(body, 4)
|
||||
|
||||
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
|
||||
body = normalizeCacheControlTTL(body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -312,10 +350,26 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
@@ -420,6 +474,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
// Keep count_tokens requests compatible with Anthropic cache-control constraints too.
|
||||
body = enforceCacheControlLimit(body, 4)
|
||||
body = normalizeCacheControlTTL(body)
|
||||
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -459,9 +517,25 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
|
||||
@@ -554,6 +628,12 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
||||
if toolChoiceType == "any" || toolChoiceType == "tool" {
|
||||
// Remove thinking configuration entirely to avoid API error
|
||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||
// Adaptive thinking may also set output_config.effort; remove it to avoid
|
||||
// leaking thinking controls when tool_choice forces tool use.
|
||||
body, _ = sjson.DeleteBytes(body, "output_config.effort")
|
||||
if oc := gjson.GetBytes(body, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
body, _ = sjson.DeleteBytes(body, "output_config")
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
@@ -576,12 +656,61 @@ func (c *compositeReadCloser) Close() error {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// peekableBody wraps a bufio.Reader around the original ReadCloser so that
|
||||
// magic bytes can be inspected without consuming them from the stream.
|
||||
type peekableBody struct {
|
||||
*bufio.Reader
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (p *peekableBody) Close() error {
|
||||
return p.closer.Close()
|
||||
}
|
||||
|
||||
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
|
||||
if body == nil {
|
||||
return nil, fmt.Errorf("response body is nil")
|
||||
}
|
||||
if contentEncoding == "" {
|
||||
return body, nil
|
||||
// No Content-Encoding header. Attempt best-effort magic-byte detection to
|
||||
// handle misbehaving upstreams that compress without setting the header.
|
||||
// Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences;
|
||||
// br and deflate have none and are left as-is.
|
||||
// The bufio wrapper preserves unread bytes so callers always see the full
|
||||
// stream regardless of whether decompression was applied.
|
||||
pb := &peekableBody{Reader: bufio.NewReader(body), closer: body}
|
||||
magic, peekErr := pb.Peek(4)
|
||||
if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) {
|
||||
switch {
|
||||
case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b:
|
||||
gzipReader, gzErr := gzip.NewReader(pb)
|
||||
if gzErr != nil {
|
||||
_ = pb.Close()
|
||||
return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr)
|
||||
}
|
||||
return &compositeReadCloser{
|
||||
Reader: gzipReader,
|
||||
closers: []func() error{
|
||||
gzipReader.Close,
|
||||
pb.Close,
|
||||
},
|
||||
}, nil
|
||||
case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd:
|
||||
decoder, zdErr := zstd.NewReader(pb)
|
||||
if zdErr != nil {
|
||||
_ = pb.Close()
|
||||
return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr)
|
||||
}
|
||||
return &compositeReadCloser{
|
||||
Reader: decoder,
|
||||
closers: []func() error{
|
||||
func() error { decoder.Close(); return nil },
|
||||
pb.Close,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return pb, nil
|
||||
}
|
||||
encodings := strings.Split(contentEncoding, ",")
|
||||
for _, raw := range encodings {
|
||||
@@ -696,23 +825,29 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
promptCachingBeta := "prompt-caching-2024-07-31"
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14," + promptCachingBeta
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
baseBetas = val
|
||||
if !strings.Contains(val, "oauth") {
|
||||
baseBetas += ",oauth-2025-04-20"
|
||||
}
|
||||
}
|
||||
if !strings.Contains(baseBetas, promptCachingBeta) {
|
||||
baseBetas += "," + promptCachingBeta
|
||||
|
||||
hasClaude1MHeader := false
|
||||
if ginHeaders != nil {
|
||||
if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok {
|
||||
hasClaude1MHeader = true
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra betas from request body
|
||||
if len(extraBetas) > 0 {
|
||||
// Merge extra betas from request body and request flags.
|
||||
if len(extraBetas) > 0 || hasClaude1MHeader {
|
||||
existingSet := make(map[string]bool)
|
||||
for _, b := range strings.Split(baseBetas, ",") {
|
||||
existingSet[strings.TrimSpace(b)] = true
|
||||
betaName := strings.TrimSpace(b)
|
||||
if betaName != "" {
|
||||
existingSet[betaName] = true
|
||||
}
|
||||
}
|
||||
for _, beta := range extraBetas {
|
||||
beta = strings.TrimSpace(beta)
|
||||
@@ -721,14 +856,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
existingSet[beta] = true
|
||||
}
|
||||
}
|
||||
if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] {
|
||||
baseBetas += ",context-1m-2025-08-07"
|
||||
}
|
||||
}
|
||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
@@ -737,13 +874,28 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||
// For User-Agent, only forward the client's header if it's already a Claude Code client.
|
||||
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
|
||||
// to avoid leaking the real client identity during cloaking.
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = ginHeaders.Get("User-Agent")
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
} else {
|
||||
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
// SSE streams must not be compressed: the downstream scanner reads
|
||||
// line-delimited text and cannot parse compressed bytes. Using
|
||||
// "identity" tells the upstream to send an uncompressed stream.
|
||||
r.Header.Set("Accept-Encoding", "identity")
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
@@ -752,6 +904,12 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||
if stream {
|
||||
r.Header.Set("Accept-Encoding", "identity")
|
||||
}
|
||||
}
|
||||
|
||||
func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
@@ -771,22 +929,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
|
||||
func checkSystemInstructions(payload []byte) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
return payload
|
||||
return checkSystemInstructionsWithMode(payload, false)
|
||||
}
|
||||
|
||||
func isClaudeOAuthToken(apiKey string) bool {
|
||||
@@ -1060,33 +1203,73 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
return payload
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithMode injects Claude Code system prompt.
|
||||
// In strict mode, it replaces all user system messages.
|
||||
// In non-strict mode (default), it prepends to existing system messages.
|
||||
// generateBillingHeader creates the x-anthropic-billing-header text block that
|
||||
// real Claude Code prepends to every system prompt array.
|
||||
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
|
||||
func generateBillingHeader(payload []byte) string {
|
||||
// Generate a deterministic cch hash from the payload content (system + messages + tools).
|
||||
// Real Claude Code uses a 5-char hex hash that varies per request.
|
||||
h := sha256.Sum256(payload)
|
||||
cch := hex.EncodeToString(h[:])[:5]
|
||||
|
||||
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
|
||||
buildBytes := make([]byte, 2)
|
||||
_, _ = rand.Read(buildBytes)
|
||||
buildHash := hex.EncodeToString(buildBytes)[:3]
|
||||
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
|
||||
//
|
||||
// system[0]: billing header (no cache_control)
|
||||
// system[1]: agent identifier (no cache_control)
|
||||
// system[2..]: user system messages (cache_control added when missing)
|
||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
|
||||
billingText := generateBillingHeader(payload)
|
||||
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
|
||||
// No cache_control on the agent block. It is a cloaking artifact with zero cache
|
||||
// value (the last system block is what actually triggers caching of all system content).
|
||||
// Including any cache_control here creates an intra-system TTL ordering violation
|
||||
// when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta
|
||||
// forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m).
|
||||
agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}`
|
||||
|
||||
if strictMode {
|
||||
// Strict mode: replace all system messages with Claude Code prompt only
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
// Strict mode: billing header + agent identifier only
|
||||
result := "[" + billingBlock + "," + agentBlock + "]"
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
|
||||
return payload
|
||||
}
|
||||
|
||||
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
// Non-strict mode: billing header + agent identifier + user system messages
|
||||
// Skip if already injected
|
||||
firstText := gjson.GetBytes(payload, "system.0.text").String()
|
||||
if strings.HasPrefix(firstText, "x-anthropic-billing-header:") {
|
||||
return payload
|
||||
}
|
||||
|
||||
result := "[" + billingBlock + "," + agentBlock
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
// Add cache_control to user system messages if not present.
|
||||
// Do NOT add ttl — let it inherit the default (5m) to avoid
|
||||
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
||||
partJSON := part.Raw
|
||||
if !part.Get("cache_control").Exists() {
|
||||
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
|
||||
}
|
||||
result += "," + partJSON
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
result += "]"
|
||||
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
|
||||
return payload
|
||||
}
|
||||
|
||||
@@ -1224,6 +1407,325 @@ func countCacheControls(payload []byte) int {
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePayloadObject(payload []byte) (map[string]any, bool) {
|
||||
if len(payload) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var root map[string]any
|
||||
if err := json.Unmarshal(payload, &root); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return root, true
|
||||
}
|
||||
|
||||
func marshalPayloadObject(original []byte, root map[string]any) []byte {
|
||||
if root == nil {
|
||||
return original
|
||||
}
|
||||
out, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return original
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asObject(v any) (map[string]any, bool) {
|
||||
obj, ok := v.(map[string]any)
|
||||
return obj, ok
|
||||
}
|
||||
|
||||
func asArray(v any) ([]any, bool) {
|
||||
arr, ok := v.([]any)
|
||||
return arr, ok
|
||||
}
|
||||
|
||||
func countCacheControlsMap(root map[string]any) int {
|
||||
count := 0
|
||||
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, item := range tools {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messages, ok := asArray(root["messages"]); ok {
|
||||
for _, msg := range messages {
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
|
||||
ccRaw, exists := obj["cache_control"]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
cc, ok := asObject(ccRaw)
|
||||
if !ok {
|
||||
*seen5m = true
|
||||
return false
|
||||
}
|
||||
ttlRaw, ttlExists := cc["ttl"]
|
||||
ttl, ttlIsString := ttlRaw.(string)
|
||||
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||
*seen5m = true
|
||||
return false
|
||||
}
|
||||
if *seen5m {
|
||||
delete(cc, "ttl")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func findLastCacheControlIndex(arr []any) int {
|
||||
last := -1
|
||||
for idx, item := range arr {
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
last = idx
|
||||
}
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
|
||||
for idx, item := range arr {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripAllCacheControl(arr []any, excess *int) {
|
||||
for _, item := range arr {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripMessageCacheControl(messages []any, excess *int) {
|
||||
for _, msg := range messages {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
|
||||
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
|
||||
// appear after a 5m-TTL block anywhere in the evaluation order.
|
||||
//
|
||||
// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages.
|
||||
// Within each section, blocks are evaluated in array order. A 5m (default) block
|
||||
// followed by a 1h block at ANY later position is an error — including within
|
||||
// the same section (e.g. system[1]=5m then system[3]=1h).
|
||||
//
|
||||
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
|
||||
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
|
||||
func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
root, ok := parsePayloadObject(payload)
|
||||
if !ok {
|
||||
return payload
|
||||
}
|
||||
|
||||
seen5m := false
|
||||
modified := false
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, tool := range tools {
|
||||
if obj, ok := asObject(tool); ok {
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messages, ok := asArray(root["messages"]); ok {
|
||||
for _, msg := range messages {
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return payload
|
||||
}
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
// enforceCacheControlLimit removes excess cache_control blocks from a payload
|
||||
// so the total does not exceed the Anthropic API limit (currently 4).
|
||||
//
|
||||
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
|
||||
// The most valuable breakpoints are:
|
||||
// 1. Last tool — caches ALL tool definitions
|
||||
// 2. Last system block — caches ALL system content
|
||||
// 3. Recent messages — cache conversation context
|
||||
//
|
||||
// Removal priority (strip lowest-value first):
|
||||
//
|
||||
// Phase 1: system blocks earliest-first, preserving the last one.
|
||||
// Phase 2: tool blocks earliest-first, preserving the last one.
|
||||
// Phase 3: message content blocks earliest-first.
|
||||
// Phase 4: remaining system blocks (last system).
|
||||
// Phase 5: remaining tool blocks (last tool).
|
||||
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
|
||||
root, ok := parsePayloadObject(payload)
|
||||
if !ok {
|
||||
return payload
|
||||
}
|
||||
|
||||
total := countCacheControlsMap(root)
|
||||
if total <= maxBlocks {
|
||||
return payload
|
||||
}
|
||||
|
||||
excess := total - maxBlocks
|
||||
|
||||
var system []any
|
||||
if arr, ok := asArray(root["system"]); ok {
|
||||
system = arr
|
||||
}
|
||||
var tools []any
|
||||
if arr, ok := asArray(root["tools"]); ok {
|
||||
tools = arr
|
||||
}
|
||||
var messages []any
|
||||
if arr, ok := asArray(root["messages"]); ok {
|
||||
messages = arr
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
stripMessageCacheControl(messages, &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
stripAllCacheControl(system, &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
stripAllCacheControl(tools, &excess)
|
||||
}
|
||||
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
||||
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
|
||||
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
|
||||
|
||||
@@ -2,12 +2,15 @@ package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -348,3 +351,632 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
|
||||
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||
}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" {
|
||||
t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h")
|
||||
}
|
||||
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
|
||||
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
|
||||
// Payload where no TTL normalization is needed (all blocks use 1h with no
|
||||
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
|
||||
// that json.Marshal would escape to \u003c etc., altering byte identity.
|
||||
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if !bytes.Equal(out, payload) {
|
||||
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||
],
|
||||
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||
"messages": [
|
||||
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
|
||||
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
|
||||
}
|
||||
if !gjson.GetBytes(out, "tools.1.cache_control").Exists() {
|
||||
t.Fatalf("tools.1.cache_control (last tool) should be preserved")
|
||||
}
|
||||
if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() {
|
||||
t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t3","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t4","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t5","cache_control":{"type":"ephemeral"}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
|
||||
}
|
||||
if !gjson.GetBytes(out, "tools.4.cache_control").Exists() {
|
||||
t.Fatalf("last tool cache_control should be preserved when possible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
|
||||
var seenBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
seenBody = bytes.Clone(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"input_tokens":42}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||
],
|
||||
"system": [
|
||||
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
|
||||
],
|
||||
"messages": [
|
||||
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-haiku-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens error: %v", err)
|
||||
}
|
||||
|
||||
if len(seenBody) == 0 {
|
||||
t.Fatal("expected count_tokens request body to be captured")
|
||||
}
|
||||
if got := countCacheControls(seenBody); got > 4 {
|
||||
t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
|
||||
}
|
||||
if hasTTLOrderingViolation(seenBody) {
|
||||
t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody))
|
||||
}
|
||||
}
|
||||
|
||||
func hasTTLOrderingViolation(payload []byte) bool {
|
||||
seen5m := false
|
||||
violates := false
|
||||
|
||||
checkCC := func(cc gjson.Result) {
|
||||
if !cc.Exists() || violates {
|
||||
return
|
||||
}
|
||||
ttl := cc.Get("ttl").String()
|
||||
if ttl != "1h" {
|
||||
seen5m = true
|
||||
return
|
||||
}
|
||||
if seen5m {
|
||||
violates = true
|
||||
}
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(payload, "tools")
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
checkCC(tool.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, item gjson.Result) bool {
|
||||
checkCC(item.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.IsArray() {
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
checkCC(item.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
return violates
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func testClaudeExecutorInvalidCompressedErrorBody(
|
||||
t *testing.T,
|
||||
invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("not-a-valid-gzip-stream"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
err := invoke(executor, auth, payload)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to decode error response body") {
|
||||
t.Fatalf("expected decode failure message, got: %v", err)
|
||||
}
|
||||
if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
||||
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
||||
// compressed SSE body that would silently break the line scanner.
|
||||
func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) {
|
||||
var gotEncoding, gotAccept string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotEncoding != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "identity")
|
||||
}
|
||||
if gotAccept != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want %q", gotAccept, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming
|
||||
// requests keep the full accept-encoding to allow response compression (which
|
||||
// decodeResponseBody handles correctly).
|
||||
func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) {
|
||||
var gotEncoding, gotAccept string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
|
||||
if gotEncoding != "gzip, deflate, br, zstd" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd")
|
||||
}
|
||||
if gotAccept != "application/json" {
|
||||
t.Errorf("Accept = %q, want %q", gotAccept, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming
|
||||
// HTTP 200 response with Content-Encoding: gzip is correctly decompressed before
|
||||
// the line scanner runs, so SSE chunks are not silently dropped.
|
||||
func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var combined strings.Builder
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error: %v", chunk.Err)
|
||||
}
|
||||
combined.Write(chunk.Payload)
|
||||
}
|
||||
|
||||
if combined.Len() == 0 {
|
||||
t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)")
|
||||
}
|
||||
if !strings.Contains(combined.String(), "message_stop") {
|
||||
t.Errorf("expected SSE content in chunks, got: %q", combined.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody
|
||||
// detects gzip-compressed content via magic bytes even when Content-Encoding is absent.
|
||||
func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte(plaintext))
|
||||
_ = gz.Close()
|
||||
|
||||
rc := io.NopCloser(&buf)
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
|
||||
// plain text untouched when Content-Encoding is absent and no magic bytes match.
|
||||
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
rc := io.NopCloser(strings.NewReader(plaintext))
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full
|
||||
// pipeline: when the upstream returns a gzip-compressed SSE body WITHOUT setting
|
||||
// Content-Encoding (a misbehaving upstream), the magic-byte sniff in
|
||||
// decodeResponseBody still decompresses it, so chunks reach the caller.
|
||||
func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var combined strings.Builder
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error: %v", chunk.Err)
|
||||
}
|
||||
combined.Write(chunk.Payload)
|
||||
}
|
||||
|
||||
if combined.Len() == 0 {
|
||||
t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)")
|
||||
}
|
||||
if !strings.Contains(combined.String(), "message_stop") {
|
||||
t.Errorf("unexpected chunk content: %q", combined.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
|
||||
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
|
||||
// path's enforced identity encoding.
|
||||
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
||||
var gotEncoding string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotEncoding != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
||||
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
|
||||
// Content-Encoding is absent.
|
||||
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewWriter: %v", err)
|
||||
}
|
||||
_, _ = enc.Write([]byte(plaintext))
|
||||
_ = enc.Close()
|
||||
|
||||
rc := io.NopCloser(&buf)
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
|
||||
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
|
||||
// the Content-Encoding header. This closes the gap left by PR #1771, which only
|
||||
// fixed header-declared compression on the error path.
|
||||
func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
|
||||
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}`
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte(errJSON))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for 400 response, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "test error") {
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies
|
||||
// the same for the streaming executor: 4xx gzip body without Content-Encoding is
|
||||
// decoded and the error message is readable.
|
||||
func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
|
||||
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}`
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte(errJSON))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for 400 response, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "stream test error") {
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
|
||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account_[uuid]_session_[uuid]
|
||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
|
||||
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
||||
// Format: user_[64-hex-chars]_account_[UUID-v4]_session_[UUID-v4]
|
||||
func generateFakeUserID() string {
|
||||
hexBytes := make([]byte, 32)
|
||||
_, _ = rand.Read(hexBytes)
|
||||
hexPart := hex.EncodeToString(hexBytes)
|
||||
uuidPart := uuid.New().String()
|
||||
return "user_" + hexPart + "_account__session_" + uuidPart
|
||||
accountUUID := uuid.New().String()
|
||||
sessionUUID := uuid.New().String()
|
||||
return "user_" + hexPart + "_account_" + accountUUID + "_session_" + sessionUUID
|
||||
}
|
||||
|
||||
// isValidUserID checks if a user ID matches Claude Code format.
|
||||
|
||||
@@ -616,6 +616,10 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
if promptCacheKey.Exists() {
|
||||
cache.ID = promptCacheKey.String()
|
||||
}
|
||||
} else if from == "openai" {
|
||||
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
||||
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||
}
|
||||
}
|
||||
|
||||
if cache.ID != "" {
|
||||
|
||||
64
internal/runtime/executor/codex_executor_cache_test.go
Normal file
64
internal/runtime/executor/codex_executor_cache_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Set("apiKey", "test-api-key")
|
||||
|
||||
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.3-codex",
|
||||
Payload: []byte(`{"model":"gpt-5.3-codex"}`),
|
||||
}
|
||||
url := "https://example.com/responses"
|
||||
|
||||
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, errRead := io.ReadAll(httpReq.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read request body: %v", errRead)
|
||||
}
|
||||
|
||||
expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()
|
||||
gotKey := gjson.GetBytes(body, "prompt_cache_key").String()
|
||||
if gotKey != expectedKey {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||
}
|
||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
||||
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
||||
}
|
||||
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||
}
|
||||
|
||||
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error (second call): %v", err)
|
||||
}
|
||||
body2, errRead2 := io.ReadAll(httpReq2.Body)
|
||||
if errRead2 != nil {
|
||||
t.Fatalf("read request body (second call): %v", errRead2)
|
||||
}
|
||||
gotKey2 := gjson.GetBytes(body2, "prompt_cache_key").String()
|
||||
if gotKey2 != expectedKey {
|
||||
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
||||
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
||||
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
||||
)
|
||||
@@ -57,11 +57,6 @@ type codexWebsocketSession struct {
|
||||
wsURL string
|
||||
authID string
|
||||
|
||||
// connCreateSent tracks whether a `response.create` message has been successfully sent
|
||||
// on the current websocket connection. The upstream expects the first message on each
|
||||
// connection to be `response.create`.
|
||||
connCreateSent bool
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
activeMu sync.Mutex
|
||||
@@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
defer sess.reqMu.Unlock()
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
// execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry == nil && connRetry != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
@@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
|
||||
executionSessionID := executionSessionIDFromOptions(opts)
|
||||
var sess *codexWebsocketSession
|
||||
if executionSessionID != "" {
|
||||
sess = e.getOrCreateSession(executionSessionID)
|
||||
sess.reqMu.Lock()
|
||||
if sess != nil {
|
||||
sess.reqMu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
sess.reqMu.Unlock()
|
||||
return nil, errDialRetry
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
@@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
|
||||
return conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
|
||||
func buildCodexWebsocketRequestBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
|
||||
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
|
||||
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
|
||||
//
|
||||
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
|
||||
// so we only use `response.append` after we have initialized the current connection.
|
||||
if allowAppend {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
|
||||
inputNode := gjson.GetBytes(body, "input")
|
||||
wsReqBody := []byte(`{}`)
|
||||
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
|
||||
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
|
||||
return wsReqBody
|
||||
}
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
|
||||
return wsReqBody
|
||||
}
|
||||
}
|
||||
|
||||
// Match codex-rs websocket v2 semantics: every request is `response.create`.
|
||||
// Incremental follow-up turns continue on the same websocket using
|
||||
// `previous_response_id` + incremental `input`, not `response.append`.
|
||||
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
||||
if errSet == nil && len(wsReqBody) > 0 {
|
||||
return wsReqBody
|
||||
@@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
|
||||
}
|
||||
}
|
||||
|
||||
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
|
||||
if sess == nil || conn == nil || len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return
|
||||
}
|
||||
|
||||
sess.connMu.Lock()
|
||||
if sess.conn == conn {
|
||||
sess.connCreateSent = true
|
||||
}
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
|
||||
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
|
||||
}
|
||||
}
|
||||
|
||||
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.SetReadDeadline(time.Now())
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
||||
if len(opts.Metadata) == 0 {
|
||||
return ""
|
||||
@@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
|
||||
sess.conn = conn
|
||||
sess.wsURL = wsURL
|
||||
sess.authID = authID
|
||||
sess.connCreateSent = false
|
||||
sess.readerConn = conn
|
||||
sess.connMu.Unlock()
|
||||
|
||||
@@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
|
||||
return
|
||||
}
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
@@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
|
||||
authID := sess.authID
|
||||
wsURL := sess.wsURL
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
|
||||
36
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
36
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %s, want response.create", got)
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
|
||||
t.Fatalf("previous_response_id = %s, want resp-1", got)
|
||||
}
|
||||
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
|
||||
t.Fatalf("input item id mismatch")
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
|
||||
t.Fatalf("unexpected websocket request type: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
|
||||
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
@@ -81,7 +80,7 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
|
||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(req)
|
||||
applyGeminiCLIHeaders(req, "unknown")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -189,7 +188,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
@@ -334,7 +333,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "text/event-stream")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
@@ -515,7 +514,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP)
|
||||
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
@@ -738,21 +737,11 @@ func stringValue(m map[string]any, key string) string {
|
||||
}
|
||||
|
||||
// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream.
|
||||
func applyGeminiCLIHeaders(r *http.Request) {
|
||||
var ginHeaders http.Header
|
||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "google-api-nodejs-client/9.15.1")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Client-Metadata", geminiCLIClientMetadata())
|
||||
}
|
||||
|
||||
// geminiCLIClientMetadata returns a compact metadata string required by upstream.
|
||||
func geminiCLIClientMetadata() string {
|
||||
// Keep parity with CLI client defaults
|
||||
return "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
// User-Agent is always forced to the GeminiCLI format regardless of the client's value,
|
||||
// so that upstream identifies the request as a native GeminiCLI client.
|
||||
func applyGeminiCLIHeaders(r *http.Request, model string) {
|
||||
r.Header.Set("User-Agent", misc.GeminiCLIUserAgent(model))
|
||||
r.Header.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader)
|
||||
}
|
||||
|
||||
// cliPreviewFallbackOrder returns preview model candidates for a base model.
|
||||
|
||||
@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
action := getVertexAction(baseModel, true)
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
// Imagen models don't support streaming, skip SSE params
|
||||
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -490,18 +491,46 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
|
||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||
|
||||
initiator := "user"
|
||||
if len(body) > 0 {
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
for _, msg := range messages.Array() {
|
||||
role := msg.Get("role").String()
|
||||
if role == "assistant" || role == "tool" {
|
||||
initiator = "agent"
|
||||
break
|
||||
}
|
||||
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" {
|
||||
initiator = "agent"
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
}
|
||||
|
||||
func detectLastConversationRole(body []byte) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
if role := arr[i].Get("role").String(); role != "" {
|
||||
return role
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
|
||||
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
|
||||
arr := inputs.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
item := arr[i]
|
||||
|
||||
// Most Responses input items carry a top-level role.
|
||||
if role := item.Get("role").String(); role != "" {
|
||||
return role
|
||||
}
|
||||
|
||||
switch item.Get("type").String() {
|
||||
case "function_call", "function_call_arguments", "computer_call":
|
||||
return "assistant"
|
||||
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||
return "tool"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// detectVisionContent checks if the request body contains vision/image content.
|
||||
@@ -803,6 +832,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
if isGitHubCopilotResponsesBuiltinTool(toolType) {
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||
continue
|
||||
}
|
||||
// Accept OpenAI format (type="function") and Claude format
|
||||
// (no type field, but has top-level name + input_schema).
|
||||
if toolType != "" && toolType != "function" {
|
||||
@@ -850,6 +883,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
}
|
||||
if toolChoice.Type == gjson.JSON {
|
||||
choiceType := toolChoice.Get("type").String()
|
||||
if isGitHubCopilotResponsesBuiltinTool(choiceType) {
|
||||
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(toolChoice.Raw))
|
||||
return body
|
||||
}
|
||||
if choiceType == "function" {
|
||||
name := toolChoice.Get("name").String()
|
||||
if name == "" {
|
||||
@@ -867,6 +904,15 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func isGitHubCopilotResponsesBuiltinTool(toolType string) bool {
|
||||
switch strings.TrimSpace(toolType) {
|
||||
case "computer", "computer_use_preview":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func collectTextFromNode(node gjson.Result) string {
|
||||
if !node.Exists() {
|
||||
return ""
|
||||
@@ -1236,3 +1282,99 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
func isHTTPSuccess(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultCopilotContextLength is the default context window for unknown Copilot models.
|
||||
defaultCopilotContextLength = 128000
|
||||
// defaultCopilotMaxCompletionTokens is the default max output tokens for unknown Copilot models.
|
||||
defaultCopilotMaxCompletionTokens = 16384
|
||||
)
|
||||
|
||||
// FetchGitHubCopilotModels dynamically fetches available models from the GitHub Copilot API.
|
||||
// It exchanges the GitHub access token stored in auth.Metadata for a Copilot API token,
|
||||
// then queries the /models endpoint. Falls back to the static registry on any failure.
|
||||
func FetchGitHubCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
if auth == nil {
|
||||
log.Debug("github-copilot: auth is nil, using static models")
|
||||
return registry.GetGitHubCopilotModels()
|
||||
}
|
||||
|
||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||
if accessToken == "" {
|
||||
log.Debug("github-copilot: no access_token in auth metadata, using static models")
|
||||
return registry.GetGitHubCopilotModels()
|
||||
}
|
||||
|
||||
copilotAuth := copilotauth.NewCopilotAuth(cfg)
|
||||
|
||||
entries, err := copilotAuth.ListModelsWithGitHubToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
log.Warnf("github-copilot: failed to fetch dynamic models: %v, using static models", err)
|
||||
return registry.GetGitHubCopilotModels()
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
log.Debug("github-copilot: API returned no models, using static models")
|
||||
return registry.GetGitHubCopilotModels()
|
||||
}
|
||||
|
||||
// Build a lookup from the static definitions so we can enrich dynamic entries
|
||||
// with known context lengths, thinking support, etc.
|
||||
staticMap := make(map[string]*registry.ModelInfo)
|
||||
for _, m := range registry.GetGitHubCopilotModels() {
|
||||
staticMap[m.ID] = m
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
models := make([]*registry.ModelInfo, 0, len(entries))
|
||||
seen := make(map[string]struct{}, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.ID == "" {
|
||||
continue
|
||||
}
|
||||
// Deduplicate model IDs to avoid incorrect reference counting.
|
||||
if _, dup := seen[entry.ID]; dup {
|
||||
continue
|
||||
}
|
||||
seen[entry.ID] = struct{}{}
|
||||
|
||||
m := ®istry.ModelInfo{
|
||||
ID: entry.ID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
if entry.Created > 0 {
|
||||
m.Created = entry.Created
|
||||
}
|
||||
if entry.Name != "" {
|
||||
m.DisplayName = entry.Name
|
||||
} else {
|
||||
m.DisplayName = entry.ID
|
||||
}
|
||||
|
||||
// Merge known metadata from the static fallback list
|
||||
if static, ok := staticMap[entry.ID]; ok {
|
||||
if m.DisplayName == entry.ID && static.DisplayName != "" {
|
||||
m.DisplayName = static.DisplayName
|
||||
}
|
||||
m.Description = static.Description
|
||||
m.ContextLength = static.ContextLength
|
||||
m.MaxCompletionTokens = static.MaxCompletionTokens
|
||||
m.SupportedEndpoints = static.SupportedEndpoints
|
||||
m.Thinking = static.Thinking
|
||||
} else {
|
||||
// Sensible defaults for models not in the static list
|
||||
m.Description = entry.ID + " via GitHub Copilot"
|
||||
m.ContextLength = defaultCopilotContextLength
|
||||
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
|
||||
}
|
||||
|
||||
models = append(models, m)
|
||||
}
|
||||
|
||||
log.Infof("github-copilot: fetched %d models from API", len(models))
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -262,15 +262,15 @@ func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
|
||||
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Claude Code typical flow: last message is user (tool result), but has assistant in history
|
||||
// Last role governs the initiator decision.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,6 +285,39 @@ func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Hi"}]},{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last role is assistant)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Use tool"}]},{"type":"function_call","call_id":"c1","name":"Read","arguments":"{}"},{"type":"function_call_output","call_id":"c1","output":"ok"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last item maps to tool role)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for x-github-api-version header (Problem M) ---
|
||||
|
||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||
|
||||
@@ -2458,7 +2458,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
|
||||
var totalUsage usage.Detail
|
||||
var hasToolUses bool // Track if any tool uses were emitted
|
||||
var hasTruncatedTools bool // Track if any tool uses were truncated
|
||||
var upstreamStopReason string // Track stop_reason from upstream events
|
||||
|
||||
// Tool use state tracking for input buffering and deduplication
|
||||
@@ -3286,59 +3285,9 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
|
||||
// Emit completed tool uses
|
||||
for _, tu := range completedToolUses {
|
||||
// Check if this tool was truncated - emit with SOFT_LIMIT_REACHED marker
|
||||
// Skip truncated tools - don't emit fake marker tool_use
|
||||
if tu.IsTruncated {
|
||||
hasTruncatedTools = true
|
||||
log.Infof("kiro: streamToChannel emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", tu.Name, tu.ToolUseID)
|
||||
|
||||
// Close text block if open
|
||||
if isTextBlockOpen && contentBlockIndex >= 0 {
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
|
||||
contentBlockIndex++
|
||||
|
||||
// Emit tool_use with SOFT_LIMIT_REACHED marker input
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
|
||||
// Build SOFT_LIMIT_REACHED marker input
|
||||
markerInput := map[string]interface{}{
|
||||
"_status": "SOFT_LIMIT_REACHED",
|
||||
"_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
|
||||
}
|
||||
|
||||
markerJSON, _ := json.Marshal(markerInput)
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(markerJSON), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
|
||||
// Close tool_use block
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
|
||||
hasToolUses = true // Keep this so stop_reason = tool_use
|
||||
log.Warnf("kiro: streamToChannel skipping truncated tool: %s (ID: %s)", tu.Name, tu.ToolUseID)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -3640,12 +3589,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
|
||||
// Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
|
||||
// SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
|
||||
stopReason := upstreamStopReason
|
||||
if hasTruncatedTools {
|
||||
// Log that we're using SOFT_LIMIT_REACHED approach
|
||||
log.Infof("kiro: streamToChannel using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use for truncated tools")
|
||||
}
|
||||
if stopReason == "" {
|
||||
if hasToolUses {
|
||||
stopReason = "tool_use"
|
||||
|
||||
@@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
|
||||
if suffixResult.HasSuffix {
|
||||
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
|
||||
} else {
|
||||
config = extractThinkingConfig(body, toFormat)
|
||||
config = extractThinkingConfig(body, fromFormat)
|
||||
if !hasThinkingConfig(config) && fromFormat != toFormat {
|
||||
config = extractThinkingConfig(body, toFormat)
|
||||
}
|
||||
}
|
||||
|
||||
if !hasThinkingConfig(config) {
|
||||
@@ -293,7 +296,10 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
|
||||
if config.Mode != ModeLevel {
|
||||
return config
|
||||
}
|
||||
if !isBudgetBasedProvider(toFormat) || !isLevelBasedProvider(fromFormat) {
|
||||
if toFormat == "claude" {
|
||||
return config
|
||||
}
|
||||
if !isBudgetCapableProvider(toFormat) {
|
||||
return config
|
||||
}
|
||||
budget, ok := ConvertLevelToBudget(string(config.Level))
|
||||
@@ -353,6 +359,26 @@ func extractClaudeConfig(body []byte) ThinkingConfig {
|
||||
if thinkingType == "disabled" {
|
||||
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||
}
|
||||
if thinkingType == "adaptive" || thinkingType == "auto" {
|
||||
// Claude adaptive thinking uses output_config.effort (low/medium/high/max).
|
||||
// We only treat it as a thinking config when effort is explicitly present;
|
||||
// otherwise we passthrough and let upstream defaults apply.
|
||||
if effort := gjson.GetBytes(body, "output_config.effort"); effort.Exists() && effort.Type == gjson.String {
|
||||
value := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
if value == "" {
|
||||
return ThinkingConfig{}
|
||||
}
|
||||
switch value {
|
||||
case "none":
|
||||
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||
case "auto":
|
||||
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||
default:
|
||||
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
|
||||
}
|
||||
}
|
||||
return ThinkingConfig{}
|
||||
}
|
||||
|
||||
// Check budget_tokens
|
||||
if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() {
|
||||
|
||||
55
internal/thinking/apply_user_defined_test.go
Normal file
55
internal/thinking/apply_user_defined_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package thinking_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "test-user-defined-claude-" + t.Name()
|
||||
modelID := "custom-claude-4-6"
|
||||
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(clientID)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "claude adaptive effort body",
|
||||
model: modelID,
|
||||
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
|
||||
},
|
||||
{
|
||||
name: "suffix level",
|
||||
model: modelID + "(high)",
|
||||
body: []byte(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyThinking() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
|
||||
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,9 @@ var levelToBudgetMap = map[string]int{
|
||||
"medium": 8192,
|
||||
"high": 24576,
|
||||
"xhigh": 32768,
|
||||
// "max" is used by Claude adaptive thinking effort. We map it to a large budget
|
||||
// and rely on per-model clamping when converting to budget-only providers.
|
||||
"max": 128000,
|
||||
}
|
||||
|
||||
// ConvertLevelToBudget converts a thinking level to a budget value.
|
||||
@@ -31,6 +34,7 @@ var levelToBudgetMap = map[string]int{
|
||||
// - medium → 8192
|
||||
// - high → 24576
|
||||
// - xhigh → 32768
|
||||
// - max → 128000
|
||||
//
|
||||
// Returns:
|
||||
// - budget: The converted budget value
|
||||
@@ -92,6 +96,43 @@ func ConvertBudgetToLevel(budget int) (string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// HasLevel reports whether the given target level exists in the levels slice.
|
||||
// Matching is case-insensitive with leading/trailing whitespace trimmed.
|
||||
func HasLevel(levels []string, target string) bool {
|
||||
for _, level := range levels {
|
||||
if strings.EqualFold(strings.TrimSpace(level), target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MapToClaudeEffort maps a generic thinking level string to a Claude adaptive
|
||||
// thinking effort value (low/medium/high/max).
|
||||
//
|
||||
// supportsMax indicates whether the target model supports "max" effort.
|
||||
// Returns the mapped effort and true if the level is valid, or ("", false) otherwise.
|
||||
func MapToClaudeEffort(level string, supportsMax bool) (string, bool) {
|
||||
level = strings.ToLower(strings.TrimSpace(level))
|
||||
switch level {
|
||||
case "":
|
||||
return "", false
|
||||
case "minimal":
|
||||
return "low", true
|
||||
case "low", "medium", "high":
|
||||
return level, true
|
||||
case "xhigh", "max":
|
||||
if supportsMax {
|
||||
return "max", true
|
||||
}
|
||||
return "high", true
|
||||
case "auto":
|
||||
return "high", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// ModelCapability describes the thinking format support of a model.
|
||||
type ModelCapability int
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// Package claude implements thinking configuration scaffolding for Claude models.
|
||||
//
|
||||
// Claude models use the thinking.budget_tokens format with values in the range
|
||||
// 1024-128000. Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5),
|
||||
// while older models do not.
|
||||
// Claude models support two thinking control styles:
|
||||
// - Manual thinking: thinking.type="enabled" with thinking.budget_tokens (token budget)
|
||||
// - Adaptive thinking (Claude 4.6): thinking.type="adaptive" with output_config.effort (low/medium/high/max)
|
||||
//
|
||||
// Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5), while older models do not.
|
||||
// See: _bmad-output/planning-artifacts/architecture.md#Epic-6
|
||||
package claude
|
||||
|
||||
@@ -34,7 +36,11 @@ func init() {
|
||||
// - Budget clamping to model range
|
||||
// - ZeroAllowed constraint enforcement
|
||||
//
|
||||
// Apply only processes ModeBudget and ModeNone; other modes are passed through unchanged.
|
||||
// Apply processes:
|
||||
// - ModeBudget: manual thinking budget_tokens
|
||||
// - ModeLevel: adaptive thinking effort (Claude 4.6)
|
||||
// - ModeAuto: provider default adaptive/manual behavior
|
||||
// - ModeNone: disabled
|
||||
//
|
||||
// Expected output format when enabled:
|
||||
//
|
||||
@@ -45,6 +51,17 @@ func init() {
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Expected output format for adaptive:
|
||||
//
|
||||
// {
|
||||
// "thinking": {
|
||||
// "type": "adaptive"
|
||||
// },
|
||||
// "output_config": {
|
||||
// "effort": "high"
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Expected output format when disabled:
|
||||
//
|
||||
// {
|
||||
@@ -60,30 +77,91 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Only process ModeBudget and ModeNone; other modes pass through
|
||||
// (caller should use ValidateConfig first to normalize modes)
|
||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced)
|
||||
// Decide enabled/disabled based on budget value
|
||||
if config.Budget == 0 {
|
||||
supportsAdaptive := modelInfo != nil && modelInfo.Thinking != nil && len(modelInfo.Thinking.Levels) > 0
|
||||
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
|
||||
case thinking.ModeLevel:
|
||||
// Adaptive thinking effort is only valid when the model advertises discrete levels.
|
||||
// (Claude 4.6 uses output_config.effort.)
|
||||
if supportsAdaptive && config.Level != "" {
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Fallback for non-adaptive Claude models: convert level to budget_tokens.
|
||||
if budget, ok := thinking.ConvertLevelToBudget(string(config.Level)); ok {
|
||||
config.Mode = thinking.ModeBudget
|
||||
config.Budget = budget
|
||||
config.Level = ""
|
||||
} else {
|
||||
return body, nil
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case thinking.ModeBudget:
|
||||
// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced).
|
||||
// Decide enabled/disabled based on budget value.
|
||||
if config.Budget == 0 {
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint).
|
||||
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||
return result, nil
|
||||
|
||||
case thinking.ModeAuto:
|
||||
// For Claude 4.6 models, auto maps to adaptive thinking with upstream defaults.
|
||||
if supportsAdaptive {
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
// Explicit effort is optional for adaptive thinking; omit it to allow upstream default.
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Legacy fallback: enable thinking without specifying budget_tokens.
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
||||
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
||||
@@ -141,7 +219,7 @@ func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo)
|
||||
}
|
||||
|
||||
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto && config.Mode != thinking.ModeLevel {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -153,14 +231,36 @@ func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
case thinking.ModeNone:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
case thinking.ModeAuto:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
case thinking.ModeLevel:
|
||||
// For user-defined models, interpret ModeLevel as Claude adaptive thinking effort.
|
||||
// Upstream is responsible for validating whether the target model supports it.
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level))
|
||||
return result, nil
|
||||
default:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -68,7 +66,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
effort := ""
|
||||
support := modelInfo.Thinking
|
||||
if config.Budget == 0 {
|
||||
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||
if support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||
effort = string(thinking.LevelNone)
|
||||
}
|
||||
}
|
||||
@@ -120,12 +118,3 @@ func applyCompatibleCodex(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func hasLevel(levels []string, target string) bool {
|
||||
for _, level := range levels {
|
||||
if strings.EqualFold(strings.TrimSpace(level), target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||
//
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||
// the unified ThinkingConfig in OpenAI format.
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format for enabled thinking
|
||||
// levels, but use thinking.type=disabled when thinking is explicitly turned off.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
@@ -17,8 +16,8 @@ import (
|
||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||
//
|
||||
// Kimi-specific behavior:
|
||||
// - Output format: reasoning_effort (string: low/medium/high)
|
||||
// - Uses OpenAI-compatible format
|
||||
// - Enabled thinking: reasoning_effort (string levels)
|
||||
// - Disabled thinking: thinking.type="disabled"
|
||||
// - Supports budget-to-level conversion
|
||||
type Applier struct{}
|
||||
|
||||
@@ -35,11 +34,19 @@ func init() {
|
||||
|
||||
// Apply applies thinking configuration to Kimi request body.
|
||||
//
|
||||
// Expected output format:
|
||||
// Expected output format (enabled):
|
||||
//
|
||||
// {
|
||||
// "reasoning_effort": "high"
|
||||
// }
|
||||
//
|
||||
// Expected output format (disabled):
|
||||
//
|
||||
// {
|
||||
// "thinking": {
|
||||
// "type": "disabled"
|
||||
// }
|
||||
// }
|
||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||
if thinking.IsUserDefinedModel(modelInfo) {
|
||||
return applyCompatibleKimi(body, config)
|
||||
@@ -60,8 +67,13 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
// Kimi uses "none" to disable thinking
|
||||
effort = string(thinking.LevelNone)
|
||||
// Respect clamped fallback level for models that cannot disable thinking.
|
||||
if config.Level != "" && config.Level != thinking.LevelNone {
|
||||
effort = string(config.Level)
|
||||
break
|
||||
}
|
||||
// Kimi requires explicit disabled thinking object.
|
||||
return applyDisabledThinking(body)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level using threshold mapping
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
@@ -79,12 +91,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
if effort == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return applyReasoningEffort(body, effort)
|
||||
}
|
||||
|
||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||
@@ -101,7 +108,9 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
effort = string(thinking.LevelNone)
|
||||
if config.Level == "" || config.Level == thinking.LevelNone {
|
||||
return applyDisabledThinking(body)
|
||||
}
|
||||
if config.Level != "" {
|
||||
effort = string(config.Level)
|
||||
}
|
||||
@@ -118,9 +127,33 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
return applyReasoningEffort(body, effort)
|
||||
}
|
||||
|
||||
func applyReasoningEffort(body []byte, effort string) ([]byte, error) {
|
||||
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
|
||||
if errDeleteThinking != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
|
||||
}
|
||||
result, errSetEffort := sjson.SetBytes(result, "reasoning_effort", effort)
|
||||
if errSetEffort != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", errSetEffort)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func applyDisabledThinking(body []byte) ([]byte, error) {
|
||||
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
|
||||
if errDeleteThinking != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
|
||||
}
|
||||
result, errDeleteEffort := sjson.DeleteBytes(result, "reasoning_effort")
|
||||
if errDeleteEffort != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to clear reasoning_effort: %w", errDeleteEffort)
|
||||
}
|
||||
result, errSetType := sjson.SetBytes(result, "thinking.type", "disabled")
|
||||
if errSetType != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set thinking.type: %w", errSetType)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
72
internal/thinking/provider/kimi/apply_test.go
Normal file
72
internal/thinking/provider/kimi/apply_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApply_ModeNone_UsesDisabledThinking(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: "kimi-k2.5",
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
body := []byte(`{"model":"kimi-k2.5","reasoning_effort":"none","thinking":{"type":"enabled","budget_tokens":2048}}`)
|
||||
|
||||
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
|
||||
if errApply != nil {
|
||||
t.Fatalf("Apply() error = %v", errApply)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "reasoning_effort").Exists() {
|
||||
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply_ModeLevel_UsesReasoningEffort(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: "kimi-k2.5",
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
body := []byte(`{"model":"kimi-k2.5","thinking":{"type":"disabled"}}`)
|
||||
|
||||
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeLevel, Level: thinking.LevelHigh}, modelInfo)
|
||||
if errApply != nil {
|
||||
t.Fatalf("Apply() error = %v", errApply)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "reasoning_effort").String(); got != "high" {
|
||||
t.Fatalf("reasoning_effort = %q, want %q, body=%s", got, "high", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking").Exists() {
|
||||
t.Fatalf("thinking should be removed when reasoning_effort is used, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply_UserDefinedModeNone_UsesDisabledThinking(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: "custom-kimi-model",
|
||||
UserDefined: true,
|
||||
}
|
||||
body := []byte(`{"model":"custom-kimi-model","reasoning_effort":"none"}`)
|
||||
|
||||
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
|
||||
if errApply != nil {
|
||||
t.Fatalf("Apply() error = %v", errApply)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "reasoning_effort").Exists() {
|
||||
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -65,7 +63,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
effort := ""
|
||||
support := modelInfo.Thinking
|
||||
if config.Budget == 0 {
|
||||
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||
if support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||
effort = string(thinking.LevelNone)
|
||||
}
|
||||
}
|
||||
@@ -117,12 +115,3 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func hasLevel(levels []string, target string) bool {
|
||||
for _, level := range levels {
|
||||
if strings.EqualFold(strings.TrimSpace(level), target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -30,13 +30,18 @@ func StripThinkingConfig(body []byte, provider string) []byte {
|
||||
var paths []string
|
||||
switch provider {
|
||||
case "claude":
|
||||
paths = []string{"thinking"}
|
||||
paths = []string{"thinking", "output_config.effort"}
|
||||
case "gemini":
|
||||
paths = []string{"generationConfig.thinkingConfig"}
|
||||
case "gemini-cli", "antigravity":
|
||||
paths = []string{"request.generationConfig.thinkingConfig"}
|
||||
case "openai":
|
||||
paths = []string{"reasoning_effort"}
|
||||
case "kimi":
|
||||
paths = []string{
|
||||
"reasoning_effort",
|
||||
"thinking",
|
||||
}
|
||||
case "codex":
|
||||
paths = []string{"reasoning.effort"}
|
||||
case "iflow":
|
||||
@@ -54,5 +59,12 @@ func StripThinkingConfig(body []byte, provider string) []byte {
|
||||
for _, path := range paths {
|
||||
result, _ = sjson.DeleteBytes(result, path)
|
||||
}
|
||||
|
||||
// Avoid leaving an empty output_config object for Claude when effort was the only field.
|
||||
if provider == "claude" {
|
||||
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func ParseSpecialSuffix(rawSuffix string) (mode ThinkingMode, ok bool) {
|
||||
// ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level.
|
||||
//
|
||||
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level.
|
||||
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh.
|
||||
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh, max.
|
||||
// Level matching is case-insensitive.
|
||||
//
|
||||
// Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix
|
||||
@@ -140,6 +140,8 @@ func ParseLevelSuffix(rawSuffix string) (level ThinkingLevel, ok bool) {
|
||||
return LevelHigh, true
|
||||
case "xhigh":
|
||||
return LevelXHigh, true
|
||||
case "max":
|
||||
return LevelMax, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -54,6 +54,9 @@ const (
|
||||
LevelHigh ThinkingLevel = "high"
|
||||
// LevelXHigh sets extra-high thinking effort
|
||||
LevelXHigh ThinkingLevel = "xhigh"
|
||||
// LevelMax sets maximum thinking effort.
|
||||
// This is currently used by Claude 4.6 adaptive thinking (opus supports "max").
|
||||
LevelMax ThinkingLevel = "max"
|
||||
)
|
||||
|
||||
// ThinkingConfig represents a unified thinking configuration.
|
||||
|
||||
@@ -53,7 +53,17 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFo
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
allowClampUnsupported := isBudgetBasedProvider(fromFormat) && isLevelBasedProvider(toFormat)
|
||||
// allowClampUnsupported determines whether to clamp unsupported levels instead of returning an error.
|
||||
// This applies when crossing provider families (e.g., openai→gemini, claude→gemini) and the target
|
||||
// model supports discrete levels. Same-family conversions require strict validation.
|
||||
toCapability := detectModelCapability(modelInfo)
|
||||
toHasLevelSupport := toCapability == CapabilityLevelOnly || toCapability == CapabilityHybrid
|
||||
allowClampUnsupported := toHasLevelSupport && !isSameProviderFamily(fromFormat, toFormat)
|
||||
|
||||
// strictBudget determines whether to enforce strict budget range validation.
|
||||
// This applies when: (1) config comes from request body (not suffix), (2) source format is known,
|
||||
// and (3) source and target are in the same provider family. Cross-family or suffix-based configs
|
||||
// are clamped instead of rejected to improve interoperability.
|
||||
strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat)
|
||||
budgetDerivedFromLevel := false
|
||||
|
||||
@@ -201,7 +211,7 @@ func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupp
|
||||
}
|
||||
|
||||
// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest.
|
||||
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh}
|
||||
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh, LevelMax}
|
||||
|
||||
// clampLevel clamps the given level to the nearest supported level.
|
||||
// On tie, prefers the lower level.
|
||||
@@ -325,7 +335,9 @@ func normalizeLevels(levels []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func isBudgetBasedProvider(provider string) bool {
|
||||
// isBudgetCapableProvider returns true if the provider supports budget-based thinking.
|
||||
// These providers may also support level-based thinking (hybrid models).
|
||||
func isBudgetCapableProvider(provider string) bool {
|
||||
switch provider {
|
||||
case "gemini", "gemini-cli", "antigravity", "claude":
|
||||
return true
|
||||
@@ -334,15 +346,6 @@ func isBudgetBasedProvider(provider string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isLevelBasedProvider(provider string) bool {
|
||||
switch provider {
|
||||
case "openai", "openai-response", "codex":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isGeminiFamily(provider string) bool {
|
||||
switch provider {
|
||||
case "gemini", "gemini-cli", "antigravity":
|
||||
@@ -352,11 +355,21 @@ func isGeminiFamily(provider string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIFamily(provider string) bool {
|
||||
switch provider {
|
||||
case "openai", "openai-response", "codex":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isSameProviderFamily(from, to string) bool {
|
||||
if from == to {
|
||||
return true
|
||||
}
|
||||
return isGeminiFamily(from) && isGeminiFamily(to)
|
||||
return (isGeminiFamily(from) && isGeminiFamily(to)) ||
|
||||
(isOpenAIFamily(from) && isOpenAIFamily(to))
|
||||
}
|
||||
|
||||
func abs(x int) int {
|
||||
|
||||
@@ -400,7 +400,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
thinkingType := thinkingResult.Get("type").String()
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive" || thinkingType == "auto")
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
@@ -431,6 +431,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// tool_choice
|
||||
toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
|
||||
if toolChoiceResult.Exists() {
|
||||
toolChoiceType := ""
|
||||
toolChoiceName := ""
|
||||
if toolChoiceResult.IsObject() {
|
||||
toolChoiceType = toolChoiceResult.Get("type").String()
|
||||
toolChoiceName = toolChoiceResult.Get("name").String()
|
||||
} else if toolChoiceResult.Type == gjson.String {
|
||||
toolChoiceType = toolChoiceResult.String()
|
||||
}
|
||||
|
||||
switch toolChoiceType {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
case "tool":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||
switch t.Get("type").String() {
|
||||
@@ -440,10 +467,20 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
case "adaptive", "auto":
|
||||
// For adaptive thinking:
|
||||
// - If output_config.effort is explicitly present, pass through as thinkingLevel.
|
||||
// - Otherwise, treat it as "enabled with target-model maximum" and emit high.
|
||||
// ApplyThinking handles clamping to target model's supported levels.
|
||||
effort := ""
|
||||
if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,6 +193,42 @@ func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolChoice_SpecificTool(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-flash-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hi"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "json",
|
||||
"description": "A JSON tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "json"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("gemini-3-flash-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if got := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" {
|
||||
t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got)
|
||||
}
|
||||
allowed := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array()
|
||||
if len(allowed) != 1 || allowed[0].String() != "json" {
|
||||
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
|
||||
@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Model
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Let user-provided generationConfig pass through
|
||||
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
|
||||
}
|
||||
|
||||
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
|
||||
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
@@ -207,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else {
|
||||
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
|
||||
}
|
||||
case "input_audio":
|
||||
audioData := item.Get("input_audio.data").String()
|
||||
audioFormat := item.Get("input_audio.format").String()
|
||||
if audioData != "" {
|
||||
audioMimeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"webm": "audio/webm",
|
||||
"pcm16": "audio/pcm",
|
||||
"g711_ulaw": "audio/basic",
|
||||
"g711_alaw": "audio/basic",
|
||||
}
|
||||
mimeType := "audio/wav"
|
||||
if audioFormat != "" {
|
||||
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||
mimeType = mapped
|
||||
} else {
|
||||
mimeType = "audio/" + audioFormat
|
||||
}
|
||||
}
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
|
||||
p++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -115,24 +116,47 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
mi := registry.LookupModelInfo(modelName, "claude")
|
||||
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
|
||||
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
|
||||
|
||||
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
|
||||
// validation errors since validate treats same-provider unsupported levels as errors.
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
|
||||
if supportsAdaptive {
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
|
||||
level = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||
}
|
||||
} else {
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -142,16 +166,35 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
if supportsAdaptive {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
default:
|
||||
level, ok := thinking.ConvertBudgetToLevel(budget)
|
||||
if ok {
|
||||
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
|
||||
level = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -68,17 +69,45 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := thinking.ConvertLevelToBudget(effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
mi := registry.LookupModelInfo(modelName, "claude")
|
||||
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
|
||||
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
|
||||
|
||||
// Claude 4.6 supports adaptive thinking with output_config.effort.
|
||||
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
|
||||
// validation errors since validate treats same-provider unsupported levels as errors.
|
||||
if supportsAdaptive {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
default:
|
||||
if budget > 0 {
|
||||
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||
effort = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||
}
|
||||
} else {
|
||||
// Legacy/manual thinking (budget_tokens).
|
||||
budget, ok := thinking.ConvertLevelToBudget(effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -174,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
|
||||
|
||||
case "image_url":
|
||||
// Convert OpenAI image format to Claude Code format
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Extract base64 data and media type from data URL
|
||||
parts := strings.Split(imageURL, ",")
|
||||
if len(parts) == 2 {
|
||||
mediaTypePart := strings.Split(parts[0], ";")[0]
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
data := parts[1]
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -262,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
case "tool":
|
||||
// Handle tool result messages conversion
|
||||
toolCallID := message.Get("tool_call_id").String()
|
||||
content := message.Get("content").String()
|
||||
toolContentResult := message.Get("content")
|
||||
|
||||
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
||||
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
||||
msg, _ = sjson.Set(msg, "content.0.content", content)
|
||||
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
|
||||
if toolResultContentRaw {
|
||||
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
|
||||
} else {
|
||||
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
messageIndex++
|
||||
}
|
||||
@@ -329,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
|
||||
switch part.Get("type").String() {
|
||||
case "text":
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
return textPart
|
||||
|
||||
case "image_url":
|
||||
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
return docPart
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertOpenAIImageURLToClaudePart(imageURL string) string {
|
||||
if imageURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
parts := strings.SplitN(imageURL, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
if mediaType == "" {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
|
||||
return imagePart
|
||||
}
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
|
||||
return imagePart
|
||||
}
|
||||
|
||||
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||
if !content.Exists() {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String(), false
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
claudeContent := "[]"
|
||||
partCount := 0
|
||||
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Type == gjson.String {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.String())
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
|
||||
partCount++
|
||||
return true
|
||||
}
|
||||
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||
partCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if partCount > 0 || len(content.Array()) == 0 {
|
||||
return claudeContent, true
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
claudePart := convertOpenAIContentPartToClaudePart(content)
|
||||
if claudePart != "" {
|
||||
claudeContent := "[]"
|
||||
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
|
||||
return claudeContent, true
|
||||
}
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_work",
|
||||
"arguments": "{\"a\":1}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": [
|
||||
{"type": "text", "text": "tool ok"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolResult := messages[1].Get("content.0")
|
||||
if got := toolResult.Get("type").String(); got != "tool_result" {
|
||||
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
|
||||
}
|
||||
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
|
||||
}
|
||||
|
||||
toolContent := toolResult.Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
|
||||
}
|
||||
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
|
||||
}
|
||||
if got := toolContent.Get("1.type").String(); got != "image" {
|
||||
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.type").String(); got != "base64" {
|
||||
t.Fatalf("Expected image source type %q, got %q", "base64", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
|
||||
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
|
||||
t.Fatalf("Unexpected base64 image data: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_work",
|
||||
"arguments": "{\"a\":1}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/tool.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content.0.content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "image" {
|
||||
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
|
||||
}
|
||||
if got := toolContent.Get("0.source.type").String(); got != "url" {
|
||||
t.Fatalf("Expected image source type %q, got %q", "url", got)
|
||||
}
|
||||
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
|
||||
t.Fatalf("Unexpected image URL: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -56,17 +57,45 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if v := root.Get("reasoning.effort"); v.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := thinking.ConvertLevelToBudget(effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
mi := registry.LookupModelInfo(modelName, "claude")
|
||||
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
|
||||
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
|
||||
|
||||
// Claude 4.6 supports adaptive thinking with output_config.effort.
|
||||
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
|
||||
// validation errors since validate treats same-provider unsupported levels as errors.
|
||||
if supportsAdaptive {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
default:
|
||||
if budget > 0 {
|
||||
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||
effort = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||
}
|
||||
} else {
|
||||
// Legacy/manual thinking (budget_tokens).
|
||||
budget, ok := thinking.ConvertLevelToBudget(effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,15 +46,23 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if systemsResult.IsArray() {
|
||||
systemResults := systemsResult.Array()
|
||||
message := `{"type":"message","role":"developer","content":[]}`
|
||||
contentIndex := 0
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemResult := systemResults[i]
|
||||
systemTypeResult := systemResult.Get("type")
|
||||
if systemTypeResult.String() == "text" {
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", i), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", i), systemResult.Get("text").String())
|
||||
text := systemResult.Get("text").String()
|
||||
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
continue
|
||||
}
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
if contentIndex > 0 {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages and transform their contents to appropriate formats.
|
||||
@@ -152,7 +160,51 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
flushMessage()
|
||||
functionCallOutputMessage := `{"type":"function_call_output"}`
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
|
||||
contentResult := messageContentResult.Get("content")
|
||||
if contentResult.IsArray() {
|
||||
toolResultContentIndex := 0
|
||||
toolResultContent := `[]`
|
||||
contentResults := contentResult.Array()
|
||||
for k := 0; k < len(contentResults); k++ {
|
||||
toolResultContentType := contentResults[k].Get("type").String()
|
||||
if toolResultContentType == "image" {
|
||||
sourceResult := contentResults[k].Get("source")
|
||||
if sourceResult.Exists() {
|
||||
data := sourceResult.Get("data").String()
|
||||
if data == "" {
|
||||
data = sourceResult.Get("base64").String()
|
||||
}
|
||||
if data != "" {
|
||||
mediaType := sourceResult.Get("media_type").String()
|
||||
if mediaType == "" {
|
||||
mediaType = sourceResult.Get("mime_type").String()
|
||||
}
|
||||
if mediaType == "" {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
|
||||
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
|
||||
toolResultContentIndex++
|
||||
}
|
||||
}
|
||||
} else if toolResultContentType == "text" {
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
|
||||
toolResultContentIndex++
|
||||
}
|
||||
}
|
||||
if toolResultContent != `[]` {
|
||||
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
|
||||
} else {
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
}
|
||||
} else {
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
}
|
||||
|
||||
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
|
||||
}
|
||||
}
|
||||
@@ -203,6 +255,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
|
||||
tool, _ = sjson.Delete(tool, "input_schema")
|
||||
tool, _ = sjson.Delete(tool, "parameters.$schema")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
tool, _ = sjson.Set(tool, "strict", false)
|
||||
template, _ = sjson.SetRaw(template, "tools.-1", tool)
|
||||
}
|
||||
@@ -222,10 +276,18 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
case "adaptive", "auto":
|
||||
// Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6).
|
||||
// Pass through directly; ApplyThinking handles clamping to target model's levels.
|
||||
effort := ""
|
||||
if v := rootResult.Get("output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
reasoningEffort = effort
|
||||
} else {
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -22,8 +23,8 @@ var (
|
||||
|
||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||
type ConvertCodexResponseToClaudeParams struct {
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
}
|
||||
|
||||
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||
|
||||
@@ -74,8 +74,13 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
|
||||
// Extract and set the model version.
|
||||
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
|
||||
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||
} else if cachedModel != "" {
|
||||
template, _ = sjson.Set(template, "model", cachedModel)
|
||||
} else if modelName != "" {
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
}
|
||||
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
modelName := "gpt-5.3-codex"
|
||||
|
||||
out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.created","response":{"id":"resp_123","created_at":1700000000,"model":"gpt-5.3-codex"}}`), ¶m)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected no output for response.created, got %d chunks", len(out))
|
||||
}
|
||||
|
||||
out = ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
gotModel := gjson.Get(out[0], "model").String()
|
||||
if gotModel != modelName {
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
modelName := "gpt-5.3-codex"
|
||||
|
||||
out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
gotModel := gjson.Get(out[0], "model").String()
|
||||
if gotModel != modelName {
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,12 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
if v := gjson.GetBytes(rawJSON, "service_tier"); v.Exists() {
|
||||
if v.String() != "priority" {
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
}
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||
|
||||
|
||||
@@ -264,18 +264,18 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"user": "test-user",
|
||||
"input": [{"role": "user", "content": "Hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify user field is deleted
|
||||
userField := gjson.Get(outputStr, "user")
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify user field is deleted
|
||||
userField := gjson.Get(outputStr, "user")
|
||||
if userField.Exists() {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
|
||||
@@ -6,24 +6,14 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||
// to OpenAI Responses SSE events (response.*).
|
||||
|
||||
func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() {
|
||||
typeStr := typeResult.String()
|
||||
if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" {
|
||||
if gjson.GetBytes(rawJSON, "response.instructions").Exists() {
|
||||
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions)
|
||||
}
|
||||
}
|
||||
}
|
||||
out := fmt.Sprintf("data: %s", string(rawJSON))
|
||||
return []string{out}
|
||||
}
|
||||
@@ -32,17 +22,12 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string
|
||||
|
||||
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
|
||||
// from a non-streaming OpenAI Chat Completions response.
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
}
|
||||
responseResult := rootResult.Get("response")
|
||||
template := responseResult.Raw
|
||||
if responseResult.Get("instructions").Exists() {
|
||||
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
||||
template, _ = sjson.Set(template, "instructions", instructions)
|
||||
}
|
||||
return template
|
||||
return responseResult.Raw
|
||||
}
|
||||
|
||||
@@ -156,6 +156,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
|
||||
@@ -171,7 +172,35 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
// tool_choice
|
||||
toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
|
||||
if toolChoiceResult.Exists() {
|
||||
toolChoiceType := ""
|
||||
toolChoiceName := ""
|
||||
if toolChoiceResult.IsObject() {
|
||||
toolChoiceType = toolChoiceResult.Get("type").String()
|
||||
toolChoiceName = toolChoiceResult.Get("name").String()
|
||||
} else if toolChoiceResult.Type == gjson.String {
|
||||
toolChoiceType = toolChoiceResult.String()
|
||||
}
|
||||
|
||||
switch toolChoiceType {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
case "tool":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini CLI thinkingConfig when enabled
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
@@ -180,10 +209,20 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
case "adaptive", "auto":
|
||||
// For adaptive thinking:
|
||||
// - If output_config.effort is explicitly present, pass through as thinkingLevel.
|
||||
// - Otherwise, treat it as "enabled with target-model maximum" and emit high.
|
||||
// ApplyThinking handles clamping to target model's supported levels.
|
||||
effort := ""
|
||||
if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-flash-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hi"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "json",
|
||||
"description": "A JSON tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "json"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" {
|
||||
t.Fatalf("Expected request.toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got)
|
||||
}
|
||||
allowed := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array()
|
||||
if len(allowed) != 1 || allowed[0].String() != "json" {
|
||||
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
|
||||
@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
// Model
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Let user-provided generationConfig pass through
|
||||
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
|
||||
}
|
||||
|
||||
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
|
||||
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -84,6 +85,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
case "tool_use":
|
||||
functionName := contentResult.Get("name").String()
|
||||
if toolUseID := contentResult.Get("id").String(); toolUseID != "" {
|
||||
if derived := toolNameFromClaudeToolUseID(toolUseID); derived != "" {
|
||||
functionName = derived
|
||||
}
|
||||
}
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
@@ -99,10 +105,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if toolCallID == "" {
|
||||
return true
|
||||
}
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName := toolNameFromClaudeToolUseID(toolCallID)
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
@@ -136,6 +141,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`)
|
||||
@@ -151,7 +157,34 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
// tool_choice
|
||||
toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
|
||||
if toolChoiceResult.Exists() {
|
||||
toolChoiceType := ""
|
||||
toolChoiceName := ""
|
||||
if toolChoiceResult.IsObject() {
|
||||
toolChoiceType = toolChoiceResult.Get("type").String()
|
||||
toolChoiceName = toolChoiceResult.Get("name").String()
|
||||
} else if toolChoiceResult.Type == gjson.String {
|
||||
toolChoiceType = toolChoiceResult.String()
|
||||
}
|
||||
|
||||
switch toolChoiceType {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "NONE")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY")
|
||||
case "tool":
|
||||
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinking config when enabled
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
switch t.Get("type").String() {
|
||||
@@ -161,10 +194,28 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
case "adaptive", "auto":
|
||||
// For adaptive thinking:
|
||||
// - If output_config.effort is explicitly present, pass through as thinkingLevel.
|
||||
// - Otherwise, treat it as "enabled with target-model maximum" and emit thinkingBudget=max.
|
||||
// ApplyThinking handles clamping to target model's supported levels.
|
||||
effort := ""
|
||||
if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
maxBudget := 0
|
||||
if mi := registry.LookupModelInfo(modelName, "gemini"); mi != nil && mi.Thinking != nil {
|
||||
maxBudget = mi.Thinking.Max
|
||||
}
|
||||
if maxBudget > 0 {
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
}
|
||||
}
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
@@ -183,3 +234,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toolNameFromClaudeToolUseID(toolUseID string) string {
|
||||
parts := strings.Split(toolUseID, "-")
|
||||
if len(parts) <= 1 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(parts[0:len(parts)-1], "-")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeRequestToGemini_ToolChoice_SpecificTool(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-flash-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hi"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "json",
|
||||
"description": "A JSON tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "json"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)
|
||||
|
||||
if got := gjson.GetBytes(output, "toolConfig.functionCallingConfig.mode").String(); got != "ANY" {
|
||||
t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got)
|
||||
}
|
||||
allowed := gjson.GetBytes(output, "toolConfig.functionCallingConfig.allowedFunctionNames").Array()
|
||||
if len(allowed) != 1 || allowed[0].String() != "json" {
|
||||
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||
}
|
||||
}
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -25,6 +25,8 @@ type Params struct {
|
||||
ResponseType int
|
||||
ResponseIndex int
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
ToolNameMap map[string]string
|
||||
SawToolCall bool
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -53,6 +55,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON),
|
||||
SawToolCall: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,8 +70,6 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Track whether tools are being used in this response chunk
|
||||
usedTool := false
|
||||
output := ""
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
@@ -175,12 +177,13 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude API compatibility
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
(*param).(*Params).SawToolCall = true
|
||||
upstreamToolName := functionCallResult.Get("name").String()
|
||||
clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName)
|
||||
|
||||
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||
// If we are already in tool use mode and name is empty, treat as continuation (delta).
|
||||
if (*param).(*Params).ResponseType == 3 && fcName == "" {
|
||||
if (*param).(*Params).ResponseType == 3 && upstreamToolName == "" {
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
@@ -221,8 +224,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.Set(data, "content_block.name", clientToolName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
@@ -249,7 +252,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
output = output + `data: `
|
||||
|
||||
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
if usedTool {
|
||||
if (*param).(*Params).SawToolCall {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
} else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
@@ -278,10 +281,10 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
// Returns:
|
||||
// - string: A Claude-compatible JSON response.
|
||||
func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
_ = originalRequestRawJSON
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("responseId").String())
|
||||
@@ -336,11 +339,12 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
upstreamToolName := functionCall.Get("name").String()
|
||||
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
||||
toolIDCounter++
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
|
||||
inputRaw := "{}"
|
||||
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
inputRaw = args.Raw
|
||||
|
||||
@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Model
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Let user-provided generationConfig pass through
|
||||
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "generationConfig", []byte(genConfig.Raw))
|
||||
}
|
||||
|
||||
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini thinkingConfig.
|
||||
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
@@ -142,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
content := m.Get("content")
|
||||
|
||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||
// system -> system_instruction as a user message style
|
||||
// system -> systemInstruction as a user message style
|
||||
if content.Type == gjson.String {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
|
||||
systemPartIndex++
|
||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||
systemPartIndex++
|
||||
} else if content.IsArray() {
|
||||
contents := content.Array()
|
||||
if len(contents) > 0 {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
for j := 0; j < len(contents); j++ {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||
systemPartIndex++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
if instructions := root.Get("instructions"); instructions.Exists() {
|
||||
systemInstr := `{"parts":[{"text":""}]}`
|
||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||
}
|
||||
|
||||
// Convert input messages to Gemini contents format
|
||||
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
if strings.EqualFold(itemRole, "system") {
|
||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||
systemInstr := ""
|
||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
||||
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
|
||||
systemInstr = systemInstructionResult.Raw
|
||||
} else {
|
||||
systemInstr = `{"parts":[]}`
|
||||
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
if systemInstr != `{"parts":[]}` {
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||
}
|
||||
}
|
||||
continue
|
||||
@@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
|
||||
}
|
||||
}
|
||||
case "input_audio":
|
||||
audioData := contentItem.Get("data").String()
|
||||
audioFormat := contentItem.Get("format").String()
|
||||
if audioData != "" {
|
||||
audioMimeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"webm": "audio/webm",
|
||||
"pcm16": "audio/pcm",
|
||||
"g711_ulaw": "audio/basic",
|
||||
"g711_alaw": "audio/basic",
|
||||
}
|
||||
mimeType := "audio/wav"
|
||||
if audioFormat != "" {
|
||||
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||
mimeType = mapped
|
||||
} else {
|
||||
mimeType = "audio/" + audioFormat
|
||||
}
|
||||
}
|
||||
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
|
||||
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
|
||||
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
|
||||
}
|
||||
}
|
||||
|
||||
if partJSON != "" {
|
||||
@@ -354,22 +381,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
funcDecl, _ = sjson.Set(funcDecl, "description", desc.String())
|
||||
}
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
// Convert parameter types from OpenAI format to Gemini format
|
||||
cleaned := params.Raw
|
||||
// Convert type values to uppercase for Gemini
|
||||
paramsResult := gjson.Parse(cleaned)
|
||||
if properties := paramsResult.Get("properties"); properties.Exists() {
|
||||
properties.ForEach(func(key, value gjson.Result) bool {
|
||||
if propType := value.Get("type"); propType.Exists() {
|
||||
upperType := strings.ToUpper(propType.String())
|
||||
cleaned, _ = sjson.Set(cleaned, "properties."+key.String()+".type", upperType)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
// Set the overall type to OBJECT
|
||||
cleaned, _ = sjson.Set(cleaned, "type", "OBJECT")
|
||||
funcDecl, _ = sjson.SetRaw(funcDecl, "parametersJsonSchema", cleaned)
|
||||
funcDecl, _ = sjson.SetRaw(funcDecl, "parametersJsonSchema", params.Raw)
|
||||
}
|
||||
|
||||
geminiTools, _ = sjson.SetRaw(geminiTools, "0.functionDeclarations.-1", funcDecl)
|
||||
|
||||
@@ -605,10 +605,6 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
})
|
||||
}
|
||||
|
||||
// Apply dynamic compression if total tools size exceeds threshold
|
||||
// This prevents 500 errors when Claude Code sends too many tools
|
||||
kiroTools = compressToolsIfNeeded(kiroTools)
|
||||
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
@@ -858,34 +854,7 @@ func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserI
|
||||
|
||||
var textContents []KiroTextContent
|
||||
|
||||
// Check if this tool_result contains error from our SOFT_LIMIT_REACHED tool_use
|
||||
// The client will return an error when trying to execute a tool with marker input
|
||||
resultStr := resultContent.String()
|
||||
isSoftLimitError := strings.Contains(resultStr, "SOFT_LIMIT_REACHED") ||
|
||||
strings.Contains(resultStr, "_status") ||
|
||||
strings.Contains(resultStr, "truncated") ||
|
||||
strings.Contains(resultStr, "missing required") ||
|
||||
strings.Contains(resultStr, "invalid input") ||
|
||||
strings.Contains(resultStr, "Error writing file")
|
||||
|
||||
if isError && isSoftLimitError {
|
||||
// Replace error content with SOFT_LIMIT_REACHED guidance
|
||||
log.Infof("kiro: detected SOFT_LIMIT_REACHED in tool_result for %s, replacing with guidance", toolUseID)
|
||||
softLimitMsg := `SOFT_LIMIT_REACHED
|
||||
|
||||
Your previous tool call was incomplete due to API output size limits.
|
||||
The content was PARTIALLY transmitted but NOT executed.
|
||||
|
||||
REQUIRED ACTION:
|
||||
1. Split your content into smaller chunks (max 300 lines per call)
|
||||
2. For file writes: Create file with first chunk, then use append for remaining
|
||||
3. Do NOT regenerate content you already attempted - continue from where you stopped
|
||||
|
||||
STATUS: This is NOT an error. Continue with smaller chunks.`
|
||||
textContents = append(textContents, KiroTextContent{Text: softLimitMsg})
|
||||
// Mark as SUCCESS so Claude doesn't treat it as a failure
|
||||
isError = false
|
||||
} else if resultContent.IsArray() {
|
||||
if resultContent.IsArray() {
|
||||
for _, item := range resultContent.Array() {
|
||||
if item.Get("type").String() == "text" {
|
||||
textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
|
||||
|
||||
@@ -55,39 +55,18 @@ func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, u
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool_use blocks - emit truncated tools with SOFT_LIMIT_REACHED marker
|
||||
hasTruncatedTools := false
|
||||
// Add tool_use blocks - skip truncated tools and log warning
|
||||
for _, toolUse := range toolUses {
|
||||
if toolUse.IsTruncated && toolUse.TruncationInfo != nil {
|
||||
// Emit tool_use with SOFT_LIMIT_REACHED marker input
|
||||
hasTruncatedTools = true
|
||||
log.Infof("kiro: buildClaudeResponse emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", toolUse.Name, toolUse.ToolUseID)
|
||||
|
||||
markerInput := map[string]interface{}{
|
||||
"_status": "SOFT_LIMIT_REACHED",
|
||||
"_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": markerInput,
|
||||
})
|
||||
} else {
|
||||
// Normal tool use
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": toolUse.Input,
|
||||
})
|
||||
log.Warnf("kiro: buildClaudeResponse skipping truncated tool: %s (ID: %s)", toolUse.Name, toolUse.ToolUseID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Log if we used SOFT_LIMIT_REACHED
|
||||
if hasTruncatedTools {
|
||||
log.Infof("kiro: buildClaudeResponse using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use")
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": toolUse.Input,
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure at least one content block (Claude API requires non-empty content)
|
||||
|
||||
@@ -192,8 +192,8 @@ func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
// Capture tool use ID for toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok {
|
||||
// Capture tool use ID only for web_search toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok && (currentToolName == "web_search" || currentToolName == "remote_web_search") {
|
||||
result.WebSearchToolUseId = id
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
// Package claude provides tool compression functionality for Kiro translator.
|
||||
// This file implements dynamic tool compression to reduce tool payload size
|
||||
// when it exceeds the target threshold, preventing 500 errors from Kiro API.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// calculateToolsSize calculates the JSON serialized size of the tools list.
|
||||
// Returns the size in bytes.
|
||||
func calculateToolsSize(tools []KiroToolWrapper) int {
|
||||
if len(tools) == 0 {
|
||||
return 0
|
||||
}
|
||||
data, err := json.Marshal(tools)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to marshal tools for size calculation: %v", err)
|
||||
return 0
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
|
||||
// simplifyInputSchema simplifies the input_schema by keeping only essential fields:
|
||||
// type, enum, required. Recursively processes nested properties.
|
||||
func simplifyInputSchema(schema interface{}) interface{} {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
schemaMap, ok := schema.(map[string]interface{})
|
||||
if !ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
simplified := make(map[string]interface{})
|
||||
|
||||
// Keep essential fields
|
||||
if t, ok := schemaMap["type"]; ok {
|
||||
simplified["type"] = t
|
||||
}
|
||||
if enum, ok := schemaMap["enum"]; ok {
|
||||
simplified["enum"] = enum
|
||||
}
|
||||
if required, ok := schemaMap["required"]; ok {
|
||||
simplified["required"] = required
|
||||
}
|
||||
|
||||
// Recursively process properties
|
||||
if properties, ok := schemaMap["properties"].(map[string]interface{}); ok {
|
||||
simplifiedProps := make(map[string]interface{})
|
||||
for key, value := range properties {
|
||||
simplifiedProps[key] = simplifyInputSchema(value)
|
||||
}
|
||||
simplified["properties"] = simplifiedProps
|
||||
}
|
||||
|
||||
// Process items for array types
|
||||
if items, ok := schemaMap["items"]; ok {
|
||||
simplified["items"] = simplifyInputSchema(items)
|
||||
}
|
||||
|
||||
// Process additionalProperties if present
|
||||
if additionalProps, ok := schemaMap["additionalProperties"]; ok {
|
||||
simplified["additionalProperties"] = simplifyInputSchema(additionalProps)
|
||||
}
|
||||
|
||||
// Process anyOf, oneOf, allOf
|
||||
for _, key := range []string{"anyOf", "oneOf", "allOf"} {
|
||||
if arr, ok := schemaMap[key].([]interface{}); ok {
|
||||
simplifiedArr := make([]interface{}, len(arr))
|
||||
for i, item := range arr {
|
||||
simplifiedArr[i] = simplifyInputSchema(item)
|
||||
}
|
||||
simplified[key] = simplifiedArr
|
||||
}
|
||||
}
|
||||
|
||||
return simplified
|
||||
}
|
||||
|
||||
// compressToolDescription compresses a description to the target length.
|
||||
// Ensures the result is at least MinToolDescriptionLength characters.
|
||||
// Uses UTF-8 safe truncation.
|
||||
func compressToolDescription(description string, targetLength int) string {
|
||||
if targetLength < kirocommon.MinToolDescriptionLength {
|
||||
targetLength = kirocommon.MinToolDescriptionLength
|
||||
}
|
||||
|
||||
if len(description) <= targetLength {
|
||||
return description
|
||||
}
|
||||
|
||||
// Find a safe truncation point (UTF-8 boundary)
|
||||
truncLen := targetLength - 3 // Leave room for "..."
|
||||
|
||||
// Ensure we don't cut in the middle of a UTF-8 character
|
||||
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||
truncLen--
|
||||
}
|
||||
|
||||
if truncLen <= 0 {
|
||||
return description[:kirocommon.MinToolDescriptionLength]
|
||||
}
|
||||
|
||||
return description[:truncLen] + "..."
|
||||
}
|
||||
|
||||
// compressToolsIfNeeded compresses tools if their total size exceeds the target threshold.
|
||||
// Compression strategy:
|
||||
// 1. First, check if compression is needed (size > ToolCompressionTargetSize)
|
||||
// 2. Step 1: Simplify input_schema (keep only type/enum/required)
|
||||
// 3. Step 2: Proportionally compress descriptions (minimum MinToolDescriptionLength chars)
|
||||
// Returns the compressed tools list.
|
||||
func compressToolsIfNeeded(tools []KiroToolWrapper) []KiroToolWrapper {
|
||||
if len(tools) == 0 {
|
||||
return tools
|
||||
}
|
||||
|
||||
originalSize := calculateToolsSize(tools)
|
||||
if originalSize <= kirocommon.ToolCompressionTargetSize {
|
||||
log.Debugf("kiro: tools size %d bytes is within target %d bytes, no compression needed",
|
||||
originalSize, kirocommon.ToolCompressionTargetSize)
|
||||
return tools
|
||||
}
|
||||
|
||||
log.Infof("kiro: tools size %d bytes exceeds target %d bytes, starting compression",
|
||||
originalSize, kirocommon.ToolCompressionTargetSize)
|
||||
|
||||
// Create a copy of tools to avoid modifying the original
|
||||
compressedTools := make([]KiroToolWrapper, len(tools))
|
||||
for i, tool := range tools {
|
||||
compressedTools[i] = KiroToolWrapper{
|
||||
ToolSpecification: KiroToolSpecification{
|
||||
Name: tool.ToolSpecification.Name,
|
||||
Description: tool.ToolSpecification.Description,
|
||||
InputSchema: KiroInputSchema{JSON: tool.ToolSpecification.InputSchema.JSON},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Simplify input_schema
|
||||
for i := range compressedTools {
|
||||
compressedTools[i].ToolSpecification.InputSchema.JSON =
|
||||
simplifyInputSchema(compressedTools[i].ToolSpecification.InputSchema.JSON)
|
||||
}
|
||||
|
||||
sizeAfterSchemaSimplification := calculateToolsSize(compressedTools)
|
||||
log.Debugf("kiro: size after schema simplification: %d bytes (reduced by %d bytes)",
|
||||
sizeAfterSchemaSimplification, originalSize-sizeAfterSchemaSimplification)
|
||||
|
||||
// Check if we're within target after schema simplification
|
||||
if sizeAfterSchemaSimplification <= kirocommon.ToolCompressionTargetSize {
|
||||
log.Infof("kiro: compression complete after schema simplification, final size: %d bytes",
|
||||
sizeAfterSchemaSimplification)
|
||||
return compressedTools
|
||||
}
|
||||
|
||||
// Step 2: Compress descriptions proportionally
|
||||
sizeToReduce := float64(sizeAfterSchemaSimplification - kirocommon.ToolCompressionTargetSize)
|
||||
var totalDescLen float64
|
||||
for _, tool := range compressedTools {
|
||||
totalDescLen += float64(len(tool.ToolSpecification.Description))
|
||||
}
|
||||
|
||||
if totalDescLen > 0 {
|
||||
// Assume size reduction comes primarily from descriptions.
|
||||
keepRatio := 1.0 - (sizeToReduce / totalDescLen)
|
||||
if keepRatio > 1.0 {
|
||||
keepRatio = 1.0
|
||||
} else if keepRatio < 0 {
|
||||
keepRatio = 0
|
||||
}
|
||||
|
||||
for i := range compressedTools {
|
||||
desc := compressedTools[i].ToolSpecification.Description
|
||||
targetLen := int(float64(len(desc)) * keepRatio)
|
||||
compressedTools[i].ToolSpecification.Description = compressToolDescription(desc, targetLen)
|
||||
}
|
||||
}
|
||||
|
||||
finalSize := calculateToolsSize(compressedTools)
|
||||
log.Infof("kiro: compression complete, original: %d bytes, final: %d bytes (%.1f%% reduction)",
|
||||
originalSize, finalSize, float64(originalSize-finalSize)/float64(originalSize)*100)
|
||||
|
||||
return compressedTools
|
||||
}
|
||||
@@ -84,13 +84,18 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
|
||||
ParsedFields: make(map[string]string),
|
||||
}
|
||||
|
||||
// Scenario 1: Empty input buffer - no data received at all
|
||||
// Scenario 1: Empty input buffer - only flag as truncation if tool has required fields
|
||||
// Many tools (e.g. TaskList, TaskGet) have no required params, so empty input is valid
|
||||
if strings.TrimSpace(rawInput) == "" {
|
||||
info.IsTruncated = true
|
||||
info.TruncationType = TruncationTypeEmptyInput
|
||||
info.ErrorMessage = "Tool input was completely empty - API response may have been truncated before tool parameters were transmitted"
|
||||
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): empty input buffer",
|
||||
info.TruncationType, toolName, toolUseID)
|
||||
if _, hasRequirements := RequiredFieldsByTool[toolName]; hasRequirements {
|
||||
info.IsTruncated = true
|
||||
info.TruncationType = TruncationTypeEmptyInput
|
||||
info.ErrorMessage = "Tool input was completely empty - API response may have been truncated before tool parameters were transmitted"
|
||||
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): empty input buffer",
|
||||
info.TruncationType, toolName, toolUseID)
|
||||
return info
|
||||
}
|
||||
log.Debugf("kiro: empty input for tool %s (ID: %s) - no required fields, treating as valid", toolName, toolUseID)
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -342,7 +347,7 @@ func buildTruncationErrorMessage(toolName, truncationType string, parsedFields m
|
||||
}
|
||||
|
||||
sb.WriteString(" Received ")
|
||||
sb.WriteString(string(rune(len(rawInput))))
|
||||
sb.WriteString(formatInt(len(rawInput)))
|
||||
sb.WriteString(" bytes. Please retry with smaller content chunks.")
|
||||
|
||||
return sb.String()
|
||||
|
||||
@@ -6,14 +6,6 @@ const (
|
||||
// Kiro API limit is 10240 bytes, leave room for "..."
|
||||
KiroMaxToolDescLen = 10237
|
||||
|
||||
// ToolCompressionTargetSize is the target total size for compressed tools (20KB).
|
||||
// If tools exceed this size, compression will be applied.
|
||||
ToolCompressionTargetSize = 20 * 1024 // 20KB
|
||||
|
||||
// MinToolDescriptionLength is the minimum description length after compression.
|
||||
// Descriptions will not be shortened below this length.
|
||||
MinToolDescriptionLength = 50
|
||||
|
||||
// ThinkingStartTag is the start tag for thinking blocks in responses.
|
||||
ThinkingStartTag = "<thinking>"
|
||||
|
||||
|
||||
@@ -75,10 +75,18 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
|
||||
case "adaptive", "auto":
|
||||
// Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6).
|
||||
// Pass through directly; ApplyThinking handles clamping to target model's levels.
|
||||
effort := ""
|
||||
if v := root.Get("output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
@@ -175,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
|
||||
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
|
||||
toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
|
||||
if toolResultContentRaw {
|
||||
toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
|
||||
} else {
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
|
||||
}
|
||||
toolResults = append(toolResults, toolResultJSON)
|
||||
}
|
||||
return true
|
||||
@@ -366,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertClaudeToolResultContentToString(content gjson.Result) string {
|
||||
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
|
||||
if !content.Exists() {
|
||||
return ""
|
||||
return "", false
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
return content.String(), false
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
var parts []string
|
||||
contentJSON := "[]"
|
||||
hasImagePart := false
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
switch {
|
||||
case item.Type == gjson.String:
|
||||
parts = append(parts, item.String())
|
||||
text := item.String()
|
||||
parts = append(parts, text)
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", text)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
|
||||
case item.IsObject() && item.Get("type").String() == "text":
|
||||
text := item.Get("text").String()
|
||||
parts = append(parts, text)
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", text)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
|
||||
case item.IsObject() && item.Get("type").String() == "image":
|
||||
contentItem, ok := convertClaudeContentPart(item)
|
||||
if ok {
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
|
||||
hasImagePart = true
|
||||
} else {
|
||||
parts = append(parts, item.Raw)
|
||||
}
|
||||
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
|
||||
parts = append(parts, item.Get("text").String())
|
||||
default:
|
||||
@@ -389,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string {
|
||||
return true
|
||||
})
|
||||
|
||||
if hasImagePart {
|
||||
return contentJSON, true
|
||||
}
|
||||
|
||||
joined := strings.Join(parts, "\n\n")
|
||||
if strings.TrimSpace(joined) != "" {
|
||||
return joined
|
||||
return joined, false
|
||||
}
|
||||
return content.Raw
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String()
|
||||
if content.Get("type").String() == "image" {
|
||||
contentItem, ok := convertClaudeContentPart(content)
|
||||
if ok {
|
||||
contentJSON := "[]"
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
|
||||
return contentJSON, true
|
||||
}
|
||||
}
|
||||
return content.Raw
|
||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String(), false
|
||||
}
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
return content.Raw
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
@@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_1",
|
||||
"content": [
|
||||
{"type": "text", "text": "tool ok"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
|
||||
}
|
||||
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
|
||||
}
|
||||
if got := toolContent.Get("1.type").String(); got != "image_url" {
|
||||
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
|
||||
}
|
||||
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
|
||||
t.Fatalf("Unexpected image_url: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_1",
|
||||
"content": {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": "https://example.com/tool.png"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "image_url" {
|
||||
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
|
||||
}
|
||||
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
|
||||
t.Fatalf("Unexpected image_url: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
|
||||
@@ -22,9 +22,11 @@ var (
|
||||
|
||||
// ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion
|
||||
type ConvertOpenAIResponseToAnthropicParams struct {
|
||||
MessageID string
|
||||
Model string
|
||||
CreatedAt int64
|
||||
MessageID string
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ToolNameMap map[string]string
|
||||
SawToolCall bool
|
||||
// Content accumulator for streaming
|
||||
ContentAccumulator strings.Builder
|
||||
// Tool calls accumulator for streaming
|
||||
@@ -78,6 +80,8 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
MessageID: "",
|
||||
Model: "",
|
||||
CreatedAt: 0,
|
||||
ToolNameMap: nil,
|
||||
SawToolCall: false,
|
||||
ContentAccumulator: strings.Builder{},
|
||||
ToolCallsAccumulator: nil,
|
||||
TextContentBlockStarted: false,
|
||||
@@ -97,6 +101,10 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
if (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap == nil {
|
||||
(*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap = util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
// Check if this is the [DONE] marker
|
||||
rawStr := strings.TrimSpace(string(rawJSON))
|
||||
if rawStr == "[DONE]" {
|
||||
@@ -111,6 +119,16 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveOpenAIFinishReason(param *ConvertOpenAIResponseToAnthropicParams) string {
|
||||
if param == nil {
|
||||
return ""
|
||||
}
|
||||
if param.SawToolCall {
|
||||
return "tool_calls"
|
||||
}
|
||||
return param.FinishReason
|
||||
}
|
||||
|
||||
// convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events
|
||||
func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -197,6 +215,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
}
|
||||
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
param.SawToolCall = true
|
||||
index := int(toolCall.Get("index").Int())
|
||||
blockIndex := param.toolContentBlockIndex(index)
|
||||
|
||||
@@ -215,7 +234,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Handle function name
|
||||
if function := toolCall.Get("function"); function.Exists() {
|
||||
if name := function.Get("name"); name.Exists() {
|
||||
accumulator.Name = name.String()
|
||||
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
|
||||
|
||||
stopThinkingContentBlock(param, &results)
|
||||
|
||||
@@ -224,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Send content_block_start for tool_use
|
||||
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
|
||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
|
||||
}
|
||||
@@ -246,7 +265,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Handle finish_reason (but don't send message_delta/message_stop yet)
|
||||
if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
|
||||
reason := finishReason.String()
|
||||
param.FinishReason = reason
|
||||
if param.SawToolCall {
|
||||
param.FinishReason = "tool_calls"
|
||||
} else {
|
||||
param.FinishReason = reason
|
||||
}
|
||||
|
||||
// Send content_block_stop for thinking content if needed
|
||||
if param.ThinkingContentBlockStarted {
|
||||
@@ -294,7 +317,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage)
|
||||
// Send message_delta with usage
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
@@ -348,7 +371,7 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
|
||||
// If we haven't sent message_delta yet (no usage info was received), send it now
|
||||
if param.FinishReason != "" && !param.MessageDeltaSent {
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
|
||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
}
|
||||
@@ -391,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
||||
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
|
||||
|
||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||
@@ -531,10 +554,10 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
|
||||
// Returns:
|
||||
// - string: An Anthropic-compatible JSON response.
|
||||
func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
_ = originalRequestRawJSON
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("model").String())
|
||||
@@ -589,8 +612,8 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
toolCalls.ForEach(func(_, tc gjson.Result) bool {
|
||||
hasToolCall = true
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
|
||||
toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String())
|
||||
toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
|
||||
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
|
||||
|
||||
argsStr := util.FixJSON(tc.Get("function.arguments").String())
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
@@ -646,8 +669,8 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
hasToolCall = true
|
||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
|
||||
|
||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
|
||||
24
internal/util/claude_tool_id.go
Normal file
24
internal/util/claude_tool_id.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
claudeToolUseIDCounter uint64
|
||||
)
|
||||
|
||||
// SanitizeClaudeToolID ensures the given id conforms to Claude's
|
||||
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
|
||||
// replaced with '_'; an empty result gets a generated fallback.
|
||||
func SanitizeClaudeToolID(id string) string {
|
||||
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
|
||||
if s == "" {
|
||||
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -430,7 +430,7 @@ func removeUnsupportedKeywords(jsonStr string) string {
|
||||
keywords := append(unsupportedConstraints,
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "$id", "additionalProperties",
|
||||
"propertyNames", "patternProperties", // Gemini doesn't support these schema keywords
|
||||
"enumTitles", "prefill", // Claude/OpenCode schema metadata fields unsupported by Gemini
|
||||
"enumTitles", "prefill", "deprecated", // Schema metadata fields unsupported by Gemini
|
||||
)
|
||||
|
||||
deletePaths := make([]string, 0)
|
||||
|
||||
@@ -6,6 +6,7 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -219,3 +220,54 @@ func FixJSON(input string) string {
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func CanonicalToolName(name string) string {
|
||||
canonical := strings.TrimSpace(name)
|
||||
canonical = strings.TrimLeft(canonical, "_")
|
||||
return strings.ToLower(canonical)
|
||||
}
|
||||
|
||||
// ToolNameMapFromClaudeRequest returns a canonical-name -> original-name map extracted from a Claude request.
|
||||
// It is used to restore exact tool name casing for clients that require strict tool name matching (e.g. Claude Code).
|
||||
func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
|
||||
if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
|
||||
return nil
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return nil
|
||||
}
|
||||
|
||||
toolResults := tools.Array()
|
||||
out := make(map[string]string, len(toolResults))
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
name := strings.TrimSpace(tool.Get("name").String())
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
key := CanonicalToolName(name)
|
||||
if key == "" {
|
||||
return true
|
||||
}
|
||||
if _, exists := out[key]; !exists {
|
||||
out[key] = name
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func MapToolName(toolNameMap map[string]string, name string) string {
|
||||
if name == "" || toolNameMap == nil {
|
||||
return name
|
||||
}
|
||||
if mapped, ok := toolNameMap[CanonicalToolName(name)]; ok && mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
@@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
ctx := &synthesizer.SynthesisContext{
|
||||
Config: cfg,
|
||||
AuthDir: resolvedAuthDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
|
||||
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
|
||||
w.fileAuthsByPath[normalizedPath] = pathAuths
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
if cfg == nil {
|
||||
if w.config == nil {
|
||||
log.Error("config is nil, cannot add or update client")
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if w.fileAuthsByPath == nil {
|
||||
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
|
||||
}
|
||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
||||
w.clientsMutex.Unlock()
|
||||
@@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after add/update")
|
||||
w.reloadCallback(cfg)
|
||||
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||
for id, a := range w.fileAuthsByPath[normalized] {
|
||||
oldByID[id] = a
|
||||
}
|
||||
|
||||
// Build synthesized auth entries for this single file only.
|
||||
sctx := &synthesizer.SynthesisContext{
|
||||
Config: w.config,
|
||||
AuthDir: w.authDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
|
||||
newByID := authSliceToMap(generated)
|
||||
if len(newByID) > 0 {
|
||||
w.fileAuthsByPath[normalized] = newByID
|
||||
} else {
|
||||
delete(w.fileAuthsByPath, normalized)
|
||||
}
|
||||
updates := w.computePerPathUpdatesLocked(oldByID, newByID)
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) removeClient(path string) {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
|
||||
for id, a := range w.fileAuthsByPath[normalized] {
|
||||
oldByID[id] = a
|
||||
}
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
delete(w.lastAuthContents, normalized)
|
||||
delete(w.fileAuthsByPath, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{})
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after removal")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))
|
||||
for id, newAuth := range newByID {
|
||||
existing, ok := w.currentAuths[id]
|
||||
if !ok {
|
||||
w.currentAuths[id] = newAuth.Clone()
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()})
|
||||
continue
|
||||
}
|
||||
if !authEqual(existing, newAuth) {
|
||||
w.currentAuths[id] = newAuth.Clone()
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()})
|
||||
}
|
||||
}
|
||||
for id := range oldByID {
|
||||
if _, stillExists := newByID[id]; stillExists {
|
||||
continue
|
||||
}
|
||||
delete(w.currentAuths, id)
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
|
||||
}
|
||||
return updates
|
||||
}
|
||||
|
||||
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
|
||||
byID := make(map[string]*coreauth.Auth, len(auths))
|
||||
for _, a := range auths {
|
||||
if a == nil || strings.TrimSpace(a.ID) == "" {
|
||||
continue
|
||||
}
|
||||
byID[a.ID] = a
|
||||
}
|
||||
return byID
|
||||
}
|
||||
|
||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
@@ -303,3 +369,79 @@ func (w *Watcher) persistAuthAsync(message string, paths ...string) {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *Watcher) stopServerUpdateTimer() {
|
||||
w.serverUpdateMu.Lock()
|
||||
defer w.serverUpdateMu.Unlock()
|
||||
if w.serverUpdateTimer != nil {
|
||||
w.serverUpdateTimer.Stop()
|
||||
w.serverUpdateTimer = nil
|
||||
}
|
||||
w.serverUpdatePend = false
|
||||
}
|
||||
|
||||
func (w *Watcher) triggerServerUpdate(cfg *config.Config) {
|
||||
if w == nil || w.reloadCallback == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
if w.stopped.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce {
|
||||
w.serverUpdateLast = now
|
||||
if w.serverUpdateTimer != nil {
|
||||
w.serverUpdateTimer.Stop()
|
||||
w.serverUpdateTimer = nil
|
||||
}
|
||||
w.serverUpdatePend = false
|
||||
w.serverUpdateMu.Unlock()
|
||||
w.reloadCallback(cfg)
|
||||
return
|
||||
}
|
||||
|
||||
if w.serverUpdatePend {
|
||||
w.serverUpdateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast)
|
||||
if delay < 10*time.Millisecond {
|
||||
delay = 10 * time.Millisecond
|
||||
}
|
||||
w.serverUpdatePend = true
|
||||
if w.serverUpdateTimer != nil {
|
||||
w.serverUpdateTimer.Stop()
|
||||
w.serverUpdateTimer = nil
|
||||
}
|
||||
var timer *time.Timer
|
||||
timer = time.AfterFunc(delay, func() {
|
||||
if w.stopped.Load() {
|
||||
return
|
||||
}
|
||||
w.clientsMutex.RLock()
|
||||
latestCfg := w.config
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
if w.serverUpdateTimer != timer || !w.serverUpdatePend {
|
||||
w.serverUpdateMu.Unlock()
|
||||
return
|
||||
}
|
||||
w.serverUpdateTimer = nil
|
||||
w.serverUpdatePend = false
|
||||
if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() {
|
||||
w.serverUpdateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
w.serverUpdateLast = time.Now()
|
||||
w.serverUpdateMu.Unlock()
|
||||
w.reloadCallback(latestCfg)
|
||||
})
|
||||
w.serverUpdateTimer = timer
|
||||
w.serverUpdateMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -127,7 +127,8 @@ func (w *Watcher) reloadConfig() bool {
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias))
|
||||
retryConfigChanged := oldConfig != nil && (oldConfig.RequestRetry != newConfig.RequestRetry || oldConfig.MaxRetryInterval != newConfig.MaxRetryInterval || oldConfig.MaxRetryCredentials != newConfig.MaxRetryCredentials)
|
||||
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias) || retryConfigChanged)
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
|
||||
@@ -54,6 +54,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryCredentials != newCfg.MaxRetryCredentials {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-credentials: %d -> %d", oldCfg.MaxRetryCredentials, newCfg.MaxRetryCredentials))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
@@ -301,6 +304,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i))
|
||||
}
|
||||
|
||||
@@ -223,6 +223,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryCredentials: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
@@ -246,6 +247,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryCredentials: 3,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
@@ -283,6 +285,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-credentials: 1 -> 3")
|
||||
expectContains(t, details, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
@@ -309,6 +312,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryCredentials: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
@@ -361,6 +365,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryCredentials: 3,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
@@ -419,6 +424,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-credentials: 1 -> 3")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
var snapshotCoreAuthsFunc = snapshotCoreAuths
|
||||
|
||||
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
@@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState(force bool) {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.RLock()
|
||||
cfg := w.config
|
||||
authDir := w.authDir
|
||||
w.clientsMutex.RUnlock()
|
||||
auths := snapshotCoreAuthsFunc(cfg, authDir)
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
for _, a := range w.runtimeAuths {
|
||||
|
||||
@@ -319,7 +319,7 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "apikey")
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, compat.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -35,9 +36,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
return out, nil
|
||||
}
|
||||
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
@@ -51,95 +49,120 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
if errRead != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
auths := synthesizeFileAuths(ctx, full, data)
|
||||
if len(auths) == 0 {
|
||||
continue
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||
id := full
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
disabled, _ := metadata["disabled"].(bool)
|
||||
status := coreauth.StatusActive
|
||||
if disabled {
|
||||
status = coreauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Read per-account excluded models from the OAuth JSON file
|
||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: status,
|
||||
Disabled: disabled,
|
||||
Attributes: map[string]string{
|
||||
"source": full,
|
||||
"path": full,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
// Read priority from auth file
|
||||
if rawPriority, ok := metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||
case string:
|
||||
priority := strings.TrimSpace(v)
|
||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||
a.Attributes["priority"] = priority
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, auths...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SynthesizeAuthFile generates Auth entries for one auth JSON file payload.
|
||||
// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize.
|
||||
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
|
||||
return synthesizeFileAuths(ctx, fullPath, data)
|
||||
}
|
||||
|
||||
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
|
||||
if ctx == nil || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
return nil
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
return nil
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store.
|
||||
id := fullPath
|
||||
if strings.TrimSpace(ctx.AuthDir) != "" {
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
id = strings.ToLower(id)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
disabled, _ := metadata["disabled"].(bool)
|
||||
status := coreauth.StatusActive
|
||||
if disabled {
|
||||
status = coreauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Read per-account excluded models from the OAuth JSON file.
|
||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: status,
|
||||
Disabled: disabled,
|
||||
Attributes: map[string]string{
|
||||
"source": fullPath,
|
||||
"path": fullPath,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
// Read priority from auth file.
|
||||
if rawPriority, ok := metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||
case string:
|
||||
priority := strings.TrimSpace(v)
|
||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||
a.Attributes["priority"] = priority
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||
}
|
||||
out := make([]*coreauth.Auth, 0, 1+len(virtuals))
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
return out
|
||||
}
|
||||
}
|
||||
return []*coreauth.Auth{a}
|
||||
}
|
||||
|
||||
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
||||
// It disables the primary auth and creates one virtual auth per project.
|
||||
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
@@ -35,10 +36,16 @@ type Watcher struct {
|
||||
clientsMutex sync.RWMutex
|
||||
configReloadMu sync.Mutex
|
||||
configReloadTimer *time.Timer
|
||||
serverUpdateMu sync.Mutex
|
||||
serverUpdateTimer *time.Timer
|
||||
serverUpdateLast time.Time
|
||||
serverUpdatePend bool
|
||||
stopped atomic.Bool
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastAuthContents map[string]*coreauth.Auth
|
||||
fileAuthsByPath map[string]map[string]*coreauth.Auth
|
||||
lastRemoveTimes map[string]time.Time
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
@@ -76,6 +83,7 @@ const (
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
configReloadDebounce = 150 * time.Millisecond
|
||||
authRemoveDebounceWindow = 1 * time.Second
|
||||
serverUpdateDebounce = 1 * time.Second
|
||||
)
|
||||
|
||||
// NewWatcher creates a new file watcher instance
|
||||
@@ -85,11 +93,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
|
||||
return nil, errNewWatcher
|
||||
}
|
||||
w := &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: reloadCallback,
|
||||
watcher: watcher,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: reloadCallback,
|
||||
watcher: watcher,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
|
||||
}
|
||||
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
||||
if store := sdkAuth.GetTokenStore(); store != nil {
|
||||
@@ -114,8 +123,10 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
|
||||
// Stop stops the file watcher
|
||||
func (w *Watcher) Stop() error {
|
||||
w.stopped.Store(true)
|
||||
w.stopDispatch()
|
||||
w.stopConfigReloadTimer()
|
||||
w.stopServerUpdateTimer()
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no reload callback for auth update, got %d", got)
|
||||
}
|
||||
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
|
||||
normalized := w.normalizeAuthPath(authFile)
|
||||
@@ -436,8 +436,150 @@ func TestRemoveClientRemovesHash(t *testing.T) {
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected hash to be removed after deletion")
|
||||
}
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no reload callback for auth removal, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authFile := filepath.Join(tmpDir, "sample.json")
|
||||
if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to create auth file: %v", err)
|
||||
}
|
||||
|
||||
origSnapshot := snapshotCoreAuthsFunc
|
||||
var snapshotCalls int32
|
||||
snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
|
||||
atomic.AddInt32(&snapshotCalls, 1)
|
||||
return origSnapshot(cfg, authDir)
|
||||
}
|
||||
defer func() { snapshotCoreAuthsFunc = origSnapshot }()
|
||||
|
||||
w := &Watcher{
|
||||
authDir: tmpDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
lastAuthContents: make(map[string]*coreauth.Auth),
|
||||
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
|
||||
}
|
||||
w.SetConfig(&config.Config{AuthDir: tmpDir})
|
||||
|
||||
w.addOrUpdateClient(authFile)
|
||||
w.removeClient(authFile)
|
||||
|
||||
if got := atomic.LoadInt32(&snapshotCalls); got != 0 {
|
||||
t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSliceToMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
valid1 := &coreauth.Auth{ID: "a"}
|
||||
valid2 := &coreauth.Auth{ID: "b"}
|
||||
dupOld := &coreauth.Auth{ID: "dup", Label: "old"}
|
||||
dupNew := &coreauth.Auth{ID: "dup", Label: "new"}
|
||||
empty := &coreauth.Auth{ID: " "}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in []*coreauth.Auth
|
||||
want map[string]*coreauth.Auth
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
in: nil,
|
||||
want: map[string]*coreauth.Auth{},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
in: []*coreauth.Auth{},
|
||||
want: map[string]*coreauth.Auth{},
|
||||
},
|
||||
{
|
||||
name: "filters invalid auths",
|
||||
in: []*coreauth.Auth{nil, empty},
|
||||
want: map[string]*coreauth.Auth{},
|
||||
},
|
||||
{
|
||||
name: "keeps valid auths",
|
||||
in: []*coreauth.Auth{valid1, nil, valid2},
|
||||
want: map[string]*coreauth.Auth{"a": valid1, "b": valid2},
|
||||
},
|
||||
{
|
||||
name: "last duplicate wins",
|
||||
in: []*coreauth.Auth{dupOld, dupNew},
|
||||
want: map[string]*coreauth.Auth{"dup": dupNew},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := authSliceToMap(tc.in)
|
||||
if len(tc.want) == 0 {
|
||||
if got == nil {
|
||||
t.Fatal("expected empty map, got nil")
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty map, got %#v", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(got) != len(tc.want) {
|
||||
t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
|
||||
}
|
||||
for id, wantAuth := range tc.want {
|
||||
gotAuth, ok := got[id]
|
||||
if !ok {
|
||||
t.Fatalf("missing id %q in result map", id)
|
||||
}
|
||||
if !authEqual(gotAuth, wantAuth) {
|
||||
t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{AuthDir: tmpDir}
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(cfg)
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond))
|
||||
w.serverUpdateMu.Unlock()
|
||||
w.triggerServerUpdate(cfg)
|
||||
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no immediate reload, got %d", got)
|
||||
}
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
if !w.serverUpdatePend || w.serverUpdateTimer == nil {
|
||||
w.serverUpdateMu.Unlock()
|
||||
t.Fatal("expected a pending server update timer")
|
||||
}
|
||||
w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond))
|
||||
w.serverUpdateMu.Unlock()
|
||||
|
||||
w.triggerServerUpdate(cfg)
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", got)
|
||||
t.Fatalf("expected immediate reload once, got %d", got)
|
||||
}
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected pending timer to be cancelled, got %d reloads", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -655,8 +797,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) {
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected reload callback once, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected no reload callback for auth removal, got %d", reloads)
|
||||
}
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected hash entry to be removed")
|
||||
@@ -853,8 +995,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
|
||||
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -950,8 +1092,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1005,8 +1147,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
|
||||
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
||||
|
||||
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||
if atomic.LoadInt32(&reloads) != 1 {
|
||||
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
|
||||
if atomic.LoadInt32(&reloads) != 0 {
|
||||
t.Fatalf("expected known remove to avoid global reload, got %d", reloads)
|
||||
}
|
||||
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||
t.Fatal("expected known auth hash to be deleted")
|
||||
@@ -1239,6 +1381,67 @@ func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadConfigTriggersCallbackForMaxRetryCredentialsChange(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
oldCfg := &config.Config{
|
||||
AuthDir: authDir,
|
||||
MaxRetryCredentials: 0,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 5,
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
AuthDir: authDir,
|
||||
MaxRetryCredentials: 2,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 5,
|
||||
}
|
||||
data, errMarshal := yaml.Marshal(newCfg)
|
||||
if errMarshal != nil {
|
||||
t.Fatalf("failed to marshal config: %v", errMarshal)
|
||||
}
|
||||
if errWrite := os.WriteFile(configPath, data, 0o644); errWrite != nil {
|
||||
t.Fatalf("failed to write config: %v", errWrite)
|
||||
}
|
||||
|
||||
callbackCalls := 0
|
||||
callbackMaxRetryCredentials := -1
|
||||
w := &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(cfg *config.Config) {
|
||||
callbackCalls++
|
||||
if cfg != nil {
|
||||
callbackMaxRetryCredentials = cfg.MaxRetryCredentials
|
||||
}
|
||||
},
|
||||
}
|
||||
w.SetConfig(oldCfg)
|
||||
|
||||
if ok := w.reloadConfig(); !ok {
|
||||
t.Fatal("expected reloadConfig to succeed")
|
||||
}
|
||||
|
||||
if callbackCalls != 1 {
|
||||
t.Fatalf("expected reload callback to be called once, got %d", callbackCalls)
|
||||
}
|
||||
if callbackMaxRetryCredentials != 2 {
|
||||
t.Fatalf("expected callback MaxRetryCredentials=2, got %d", callbackMaxRetryCredentials)
|
||||
}
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
if w.config == nil || w.config.MaxRetryCredentials != 2 {
|
||||
t.Fatalf("expected watcher config MaxRetryCredentials=2, got %+v", w.config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartFailsWhenAuthDirMissing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
@@ -14,7 +14,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -26,7 +30,6 @@ const (
|
||||
wsRequestTypeAppend = "response.append"
|
||||
wsEventTypeError = "error"
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsEventTypeDone = "response.done"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
@@ -101,11 +104,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||
allowIncrementalInputWithPreviousResponseID := false
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
} else {
|
||||
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if requestModelName == "" {
|
||||
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
}
|
||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
@@ -140,6 +149,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||
requestJSON = updated
|
||||
}
|
||||
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
|
||||
updatedLastRequest = updated
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
lastResponseOutput = []byte("[]")
|
||||
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
||||
wsTerminateErr = errWrite
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
@@ -340,6 +365,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||
if h == nil || h.AuthManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
resolvedModelName := modelName
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||
if initialSuffix.HasSuffix {
|
||||
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||
} else {
|
||||
resolvedModelName = resolvedBase
|
||||
}
|
||||
} else {
|
||||
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||
}
|
||||
|
||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||
providers := util.GetProviderName(baseModel)
|
||||
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||
providers = util.GetProviderName(resolvedModelName)
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for i := 0; i < len(providers); i++ {
|
||||
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[providerKey] = struct{}{}
|
||||
}
|
||||
if len(providerSet) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
modelKey := baseModel
|
||||
if modelKey == "" {
|
||||
modelKey = strings.TrimSpace(resolvedModelName)
|
||||
}
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
now := time.Now()
|
||||
auths := h.AuthManager.List()
|
||||
for i := 0; i < len(auths); i++ {
|
||||
auth := auths[i]
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
||||
continue
|
||||
}
|
||||
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
||||
continue
|
||||
}
|
||||
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
|
||||
return false
|
||||
}
|
||||
if modelName != "" && len(auth.ModelStates) > 0 {
|
||||
state, ok := auth.ModelStates[modelName]
|
||||
if (!ok || state == nil) && modelName != "" {
|
||||
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
|
||||
if baseModel != "" && baseModel != modelName {
|
||||
state, ok = auth.ModelStates[baseModel]
|
||||
}
|
||||
}
|
||||
if ok && state != nil {
|
||||
if state.Status == coreauth.StatusDisabled {
|
||||
return false
|
||||
}
|
||||
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
|
||||
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||
return false
|
||||
}
|
||||
generateResult := gjson.GetBytes(rawJSON, "generate")
|
||||
return generateResult.Exists() && !generateResult.Bool()
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketSyntheticPrewarm(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
requestJSON []byte,
|
||||
wsBodyLog *strings.Builder,
|
||||
sessionID string,
|
||||
) error {
|
||||
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
||||
if errPayloads != nil {
|
||||
return errPayloads
|
||||
}
|
||||
for i := 0; i < len(payloads); i++ {
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
// websocket.TextMessage,
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(payloads[i]),
|
||||
errWrite,
|
||||
)
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
|
||||
responseID := "resp_prewarm_" + uuid.NewString()
|
||||
createdAt := time.Now().Unix()
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
|
||||
|
||||
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||
var errSet error
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
if modelName != "" {
|
||||
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
if modelName != "" {
|
||||
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
return [][]byte{createdPayload, completedPayload}, nil
|
||||
}
|
||||
|
||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||
existingRaw = strings.TrimSpace(existingRaw)
|
||||
appendRaw = strings.TrimSpace(appendRaw)
|
||||
@@ -469,9 +680,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
for i := range payloads {
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||
|
||||
completed = true
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
@@ -554,47 +762,63 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
|
||||
}
|
||||
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
payload := map[string]any{
|
||||
"type": wsEventTypeError,
|
||||
"status": status,
|
||||
payload := []byte(`{}`)
|
||||
var errSet error
|
||||
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
payload, errSet = sjson.SetBytes(payload, "status", status)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
|
||||
if errMsg != nil && errMsg.Addon != nil {
|
||||
headers := map[string]any{}
|
||||
headers := []byte(`{}`)
|
||||
hasHeaders := false
|
||||
for key, values := range errMsg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
headers[key] = values[0]
|
||||
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
|
||||
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
hasHeaders = true
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
payload["headers"] = headers
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
var decoded map[string]any
|
||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||
if inner, ok := decoded["error"]; ok {
|
||||
payload["error"] = inner
|
||||
} else {
|
||||
payload["error"] = decoded
|
||||
if hasHeaders {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; !ok {
|
||||
payload["error"] = map[string]any{
|
||||
"type": "server_error",
|
||||
"message": errText,
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
errorNode := gjson.GetBytes(body, "error")
|
||||
if errorNode.Exists() {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
|
||||
} else {
|
||||
payload, errSet = sjson.SetRawBytes(payload, "error", body)
|
||||
}
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !gjson.GetBytes(payload, "error").Exists() {
|
||||
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
}
|
||||
}
|
||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
|
||||
@@ -2,15 +2,57 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type websocketCaptureExecutor struct {
|
||||
streamCalls int
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.streamCalls++
|
||||
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
|
||||
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||
close(chunks)
|
||||
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
@@ -247,3 +289,206 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
serverErrCh <- errClose
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Request = r
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errCh := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
|
||||
close(data)
|
||||
close(errCh)
|
||||
|
||||
var bodyLog strings.Builder
|
||||
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
|
||||
ctx,
|
||||
conn,
|
||||
func(...interface{}) {},
|
||||
data,
|
||||
errCh,
|
||||
&bodyLog,
|
||||
"session-1",
|
||||
)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
|
||||
serverErrCh <- errors.New("completed output not captured")
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if strings.Contains(string(payload), "response.done") {
|
||||
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
|
||||
}
|
||||
|
||||
if errServer := <-serverErrCh; errServer != nil {
|
||||
t.Fatalf("server error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-ws",
|
||||
Provider: "test-provider",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{"websockets": "true"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
|
||||
t.Fatalf("expected websocket-capable upstream for test-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
executor := &websocketCaptureExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
|
||||
if errWrite != nil {
|
||||
t.Fatalf("write prewarm websocket message: %v", errWrite)
|
||||
}
|
||||
|
||||
_, createdPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read prewarm created message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
|
||||
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
|
||||
}
|
||||
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
|
||||
if prewarmResponseID == "" {
|
||||
t.Fatalf("prewarm response id is empty")
|
||||
}
|
||||
if executor.streamCalls != 0 {
|
||||
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
|
||||
}
|
||||
|
||||
_, completedPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read prewarm completed message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
|
||||
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
|
||||
}
|
||||
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
|
||||
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
|
||||
}
|
||||
|
||||
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
|
||||
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
|
||||
if errWrite != nil {
|
||||
t.Fatalf("write follow-up websocket message: %v", errWrite)
|
||||
}
|
||||
|
||||
_, upstreamPayload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read upstream completed message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if executor.streamCalls != 1 {
|
||||
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
|
||||
}
|
||||
if len(executor.payloads) != 1 {
|
||||
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
|
||||
}
|
||||
forwarded := executor.payloads[0]
|
||||
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
|
||||
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
|
||||
}
|
||||
if gjson.GetBytes(forwarded, "generate").Exists() {
|
||||
t.Fatalf("generate leaked upstream: %s", forwarded)
|
||||
}
|
||||
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
|
||||
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
|
||||
}
|
||||
input := gjson.GetBytes(forwarded, "input").Array()
|
||||
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -266,14 +267,17 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
}
|
||||
|
||||
func (s *FileTokenStore) idFor(path, baseDir string) string {
|
||||
if baseDir == "" {
|
||||
return path
|
||||
id := path
|
||||
if baseDir != "" {
|
||||
if rel, errRel := filepath.Rel(baseDir, path); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
}
|
||||
rel, err := filepath.Rel(baseDir, path)
|
||||
if err != nil {
|
||||
return path
|
||||
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
|
||||
if runtime.GOOS == "windows" {
|
||||
id = strings.ToLower(id)
|
||||
}
|
||||
return rel
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *FileTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user