mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-23 01:38:01 +00:00
Compare commits
151 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3c59165d7 | ||
|
|
f81acd0760 | ||
|
|
636da4c932 | ||
|
|
cccb77b552 | ||
|
|
2bd646ad70 | ||
|
|
56073ded69 | ||
|
|
9738a53f49 | ||
|
|
be3f8dbf7e | ||
|
|
9c6c3612a8 | ||
|
|
19e1a4447a | ||
|
|
7c2ad4cda2 | ||
|
|
fb95813fbf | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
cef2aeeb08 | ||
|
|
bcd1e8cc34 | ||
|
|
198b3f4a40 | ||
|
|
9fee7f488e | ||
|
|
1b46d39b8b | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
152c310bb7 | ||
|
|
f6bbca35ab | ||
|
|
c8cee6a209 | ||
|
|
b5701f416b | ||
|
|
4b1a404fcb | ||
|
|
b93cce5412 | ||
|
|
c6cb24039d | ||
|
|
5382408489 | ||
|
|
67669196ed | ||
|
|
58fd9bf964 | ||
|
|
7b3dfc67bc | ||
|
|
cdd24052d3 | ||
|
|
733fd8edab | ||
|
|
af27f2b8bc | ||
|
|
2e1925d762 | ||
|
|
77254bd074 | ||
|
|
5b6342e6ac | ||
|
|
3960c93d51 | ||
|
|
339a81b650 | ||
|
|
560c020477 | ||
|
|
aec65e3be3 | ||
|
|
f44f0702f8 | ||
|
|
b76b79068f | ||
|
|
34c8ccb961 | ||
|
|
d08e164af3 | ||
|
|
8178efaeda | ||
|
|
86d5db472a | ||
|
|
020d36f6e8 | ||
|
|
1db23979e8 | ||
|
|
c3d5dbe96f | ||
|
|
5484489406 | ||
|
|
0ac52da460 | ||
|
|
817cebb321 | ||
|
|
683f3709d6 | ||
|
|
dbd42a42b2 | ||
|
|
ec24baf757 | ||
|
|
dea3e74d35 | ||
|
|
a6c3042e34 | ||
|
|
861537c9bd | ||
|
|
8c92cb0883 | ||
|
|
89d7be9525 | ||
|
|
2b79d7f22f | ||
|
|
2bb686f594 | ||
|
|
163fe287ce | ||
|
|
70988d387b | ||
|
|
52058a1659 | ||
|
|
df5595a0c9 | ||
|
|
ddaa9d2436 | ||
|
|
7b7b258c38 | ||
|
|
a00f774f5a | ||
|
|
9daf1ba8b5 | ||
|
|
76f2359637 | ||
|
|
dcb1c9be8a | ||
|
|
a24f4ace78 | ||
|
|
c631df8c3b | ||
|
|
54c3eb1b1e | ||
|
|
bb28cd26ad | ||
|
|
046865461e | ||
|
|
cf74ed2f0c | ||
|
|
e333fbea3d | ||
|
|
efbe36d1d4 | ||
|
|
8553cfa40e | ||
|
|
30d5c95b26 | ||
|
|
d1e3195e6f | ||
|
|
05a35662ae | ||
|
|
ce53d3a287 | ||
|
|
4cc99e7449 | ||
|
|
71773fe032 | ||
|
|
a1e0fa0f39 | ||
|
|
fc2f0b6983 | ||
|
|
5c9997cdac | ||
|
|
6f81046730 | ||
|
|
0687472d01 | ||
|
|
7739738fb3 | ||
|
|
99d1ce247b | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 | ||
|
|
182b31963a | ||
|
|
4f48e5254a | ||
|
|
15dd5db1d7 | ||
|
|
424711b718 | ||
|
|
91a2b1f0b4 | ||
|
|
2b134fc378 | ||
|
|
b9153719b0 | ||
|
|
631e5c8331 | ||
|
|
e9c60a0a67 | ||
|
|
98a1bb5a7f | ||
|
|
ca90487a8c | ||
|
|
1042489f85 | ||
|
|
38277c1ea6 | ||
|
|
ee0c24628f | ||
|
|
3a18f6fcca | ||
|
|
099e734a02 | ||
|
|
a52da26b5d | ||
|
|
522a68a4ea | ||
|
|
a02eda54d0 | ||
|
|
97ef633c57 | ||
|
|
dae8463ba1 | ||
|
|
7c1299922e | ||
|
|
ddcf1f279d | ||
|
|
7e6bb8fdc5 | ||
|
|
9cee8ef87b | ||
|
|
93fb841bcb | ||
|
|
0c05131aeb | ||
|
|
5ebc58fab4 | ||
|
|
2b609dd891 | ||
|
|
a8cbc68c3e | ||
|
|
11a795a01c | ||
|
|
242aecd924 | ||
|
|
ce8cc1ba33 | ||
|
|
97fdd2e088 | ||
|
|
553d6f50ea | ||
|
|
dd44413ba5 | ||
|
|
10fa0f2062 | ||
|
|
30338ecec4 | ||
|
|
9a37defed3 | ||
|
|
c83a057996 | ||
|
|
b7588428c5 |
14
.github/workflows/docker-image.yml
vendored
14
.github/workflows/docker-image.yml
vendored
@@ -16,6 +16,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
@@ -25,7 +29,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push (amd64)
|
||||
@@ -47,6 +51,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
@@ -56,7 +64,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push (arm64)
|
||||
@@ -90,7 +98,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Create and push multi-arch manifests
|
||||
|
||||
4
.github/workflows/pr-test-build.yml
vendored
4
.github/workflows/pr-test-build.yml
vendored
@@ -12,6 +12,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
|
||||
9
.github/workflows/release.yaml
vendored
9
.github/workflows/release.yaml
vendored
@@ -16,6 +16,10 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Refresh models catalog
|
||||
run: |
|
||||
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
@@ -23,15 +27,14 @@ jobs:
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
VERSION=$(git describe --tags --always --dirty)
|
||||
echo "VERSION=${VERSION}" >> $GITHUB_ENV
|
||||
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
distribution: goreleaser
|
||||
version: latest
|
||||
args: release --clean
|
||||
args: release --clean --skip=validate
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
VERSION: ${{ env.VERSION }}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
version: 2
|
||||
|
||||
builds:
|
||||
- id: "cli-proxy-api-plus"
|
||||
env:
|
||||
|
||||
117
README.md
117
README.md
@@ -8,123 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
|
||||
|
||||
The Plus release stays in lockstep with the mainline features.
|
||||
|
||||
## Differences from the Mainline
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
## New Features (Plus Enhanced)
|
||||
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
## Kiro Authentication
|
||||
|
||||
### CLI Login
|
||||
|
||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
|
||||
**AWS Builder ID** (recommended):
|
||||
|
||||
```bash
|
||||
# Device code flow
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# Authorization code flow
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**Import token from Kiro IDE:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
To get a token from Kiro IDE:
|
||||
|
||||
1. Open Kiro IDE and login with Google (or GitHub)
|
||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
||||
3. Run: `./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# Specify region
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**Additional flags:**
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
||||
|
||||
### Web-based OAuth Login
|
||||
|
||||
Access the Kiro OAuth web interface at:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
||||
- AWS Builder ID login
|
||||
- AWS Identity Center (IDC) login
|
||||
- Token import from Kiro IDE
|
||||
|
||||
## Quick Deployment with Docker
|
||||
|
||||
### One-Command Deployment
|
||||
|
||||
```bash
|
||||
# Create deployment directory
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# Create docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# Download example config
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# Pull and start
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Edit `config.yaml` before starting:
|
||||
|
||||
```yaml
|
||||
# Basic configuration example
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# Add your provider configurations here
|
||||
```
|
||||
|
||||
### Update to Latest Version
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||
|
||||
121
README_CN.md
121
README_CN.md
@@ -6,125 +6,6 @@
|
||||
|
||||
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
||||
|
||||
该 Plus 版本的主线功能与主线功能强制同步。
|
||||
|
||||
## 与主线版本版本差异
|
||||
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
## 新增功能 (Plus 增强版)
|
||||
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
||||
|
||||
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
### 命令行登录
|
||||
|
||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
||||
|
||||
**AWS Builder ID**(推荐):
|
||||
|
||||
```bash
|
||||
# 设备码流程
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# 授权码流程
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**从 Kiro IDE 导入令牌:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
获取令牌步骤:
|
||||
|
||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# 指定区域
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**附加参数:**
|
||||
|
||||
| 参数 | 说明 |
|
||||
|------|------|
|
||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
||||
|
||||
### 网页端 OAuth 登录
|
||||
|
||||
访问 Kiro OAuth 网页认证界面:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
||||
- AWS Builder ID 登录
|
||||
- AWS Identity Center (IDC) 登录
|
||||
- 从 Kiro IDE 导入令牌
|
||||
|
||||
## Docker 快速部署
|
||||
|
||||
### 一键部署
|
||||
|
||||
```bash
|
||||
# 创建部署目录
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# 创建 docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# 下载示例配置
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# 拉取并启动
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
启动前请编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# 基本配置示例
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# 在此添加你的供应商配置
|
||||
```
|
||||
|
||||
### 更新到最新版本
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||
@@ -133,4 +14,4 @@ docker compose pull && docker compose up -d
|
||||
|
||||
## 许可证
|
||||
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
|
||||
275
cmd/fetch_antigravity_models/main.go
Normal file
275
cmd/fetch_antigravity_models/main.go
Normal file
@@ -0,0 +1,275 @@
|
||||
// Command fetch_antigravity_models connects to the Antigravity API using the
|
||||
// stored auth credentials and saves the dynamically fetched model list to a
|
||||
// JSON file for inspection or offline use.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// go run ./cmd/fetch_antigravity_models [flags]
|
||||
//
|
||||
// Flags:
|
||||
//
|
||||
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
|
||||
// --output <path> Output JSON file path (default: "antigravity_models.json")
|
||||
// --pretty Pretty-print the output JSON (default: true)
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
|
||||
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logging.SetupBaseLogger()
|
||||
log.SetLevel(log.InfoLevel)
|
||||
}
|
||||
|
||||
// modelOutput wraps the fetched model list with fetch metadata.
|
||||
type modelOutput struct {
|
||||
Models []modelEntry `json:"models"`
|
||||
}
|
||||
|
||||
// modelEntry contains only the fields we want to keep for static model definitions.
|
||||
type modelEntry struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ContextLength int `json:"context_length,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var authsDir string
|
||||
var outputPath string
|
||||
var pretty bool
|
||||
|
||||
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
|
||||
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
|
||||
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
|
||||
flag.Parse()
|
||||
|
||||
// Resolve relative paths against the working directory.
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if !filepath.IsAbs(authsDir) {
|
||||
authsDir = filepath.Join(wd, authsDir)
|
||||
}
|
||||
if !filepath.IsAbs(outputPath) {
|
||||
outputPath = filepath.Join(wd, outputPath)
|
||||
}
|
||||
|
||||
fmt.Printf("Scanning auth files in: %s\n", authsDir)
|
||||
|
||||
// Load all auth records from the directory.
|
||||
fileStore := sdkauth.NewFileTokenStore()
|
||||
fileStore.SetBaseDir(authsDir)
|
||||
|
||||
ctx := context.Background()
|
||||
auths, err := fileStore.List(ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Find the first enabled antigravity auth.
|
||||
var chosen *coreauth.Auth
|
||||
for _, a := range auths {
|
||||
if a == nil || a.Disabled {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
|
||||
chosen = a
|
||||
break
|
||||
}
|
||||
}
|
||||
if chosen == nil {
|
||||
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
|
||||
|
||||
// Fetch models from the upstream Antigravity API.
|
||||
fmt.Println("Fetching Antigravity model list from upstream...")
|
||||
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
models := fetchModels(fetchCtx, chosen)
|
||||
if len(models) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
|
||||
} else {
|
||||
fmt.Printf("Fetched %d models.\n", len(models))
|
||||
}
|
||||
|
||||
// Build the output payload.
|
||||
out := modelOutput{
|
||||
Models: models,
|
||||
}
|
||||
|
||||
// Marshal to JSON.
|
||||
var raw []byte
|
||||
if pretty {
|
||||
raw, err = json.MarshalIndent(out, "", " ")
|
||||
} else {
|
||||
raw, err = json.Marshal(out)
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Model list saved to: %s\n", outputPath)
|
||||
}
|
||||
|
||||
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||
if accessToken == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
|
||||
return nil
|
||||
}
|
||||
|
||||
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
|
||||
|
||||
for _, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
|
||||
var payload []byte
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||
}
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
payload = []byte(`{}`)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
|
||||
if errReq != nil {
|
||||
continue
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
|
||||
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||
httpClient.Transport = transport
|
||||
}
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
httpResp.Body.Close()
|
||||
if errRead != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
continue
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
var models []modelEntry
|
||||
|
||||
for originalName, modelData := range result.Map() {
|
||||
modelID := strings.TrimSpace(originalName)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
// Skip internal/experimental models
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := modelData.Get("displayName").String()
|
||||
if displayName == "" {
|
||||
displayName = modelID
|
||||
}
|
||||
|
||||
entry := modelEntry{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
OwnedBy: "antigravity",
|
||||
Type: "antigravity",
|
||||
DisplayName: displayName,
|
||||
Name: modelID,
|
||||
Description: displayName,
|
||||
}
|
||||
|
||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||
entry.ContextLength = int(maxTok)
|
||||
}
|
||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||
entry.MaxCompletionTokens = int(maxOut)
|
||||
}
|
||||
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func metaStringValue(m map[string]interface{}, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
@@ -78,6 +79,8 @@ func main() {
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var gitlabLogin bool
|
||||
var gitlabTokenLogin bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
@@ -100,6 +103,7 @@ func main() {
|
||||
var standalone bool
|
||||
var noIncognito bool
|
||||
var useIncognito bool
|
||||
var localModel bool
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
@@ -110,6 +114,8 @@ func main() {
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
|
||||
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
@@ -132,6 +138,7 @@ func main() {
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -526,6 +533,10 @@ func main() {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if gitlabLogin {
|
||||
cmd.DoGitLabLogin(cfg, options)
|
||||
} else if gitlabTokenLogin {
|
||||
cmd.DoGitLabTokenLogin(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
@@ -569,10 +580,16 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
if localModel && (!tuiMode || standalone) {
|
||||
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
|
||||
}
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
if !localModel {
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
}
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
@@ -643,15 +660,18 @@ func main() {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
if !localModel {
|
||||
registry.StartModelsUpdater(context.Background())
|
||||
}
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,8 @@ error-logs-max-files: 10
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ''
|
||||
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
@@ -115,6 +116,7 @@ nonstream-keepalive-interval: 0
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
@@ -133,6 +135,7 @@ nonstream-keepalive-interval: 0
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
@@ -151,6 +154,7 @@ nonstream-keepalive-interval: 0
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
@@ -178,6 +182,14 @@ nonstream-keepalive-interval: 0
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# Default headers for Codex OAuth model requests.
|
||||
# These are used only for file-backed/OAuth Codex requests when the client
|
||||
# does not send the header. `user-agent` applies to HTTP and websocket requests;
|
||||
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
|
||||
# codex-header-defaults:
|
||||
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
|
||||
# beta-features: "multi_agent"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
#kiro:
|
||||
@@ -215,17 +227,30 @@ nonstream-keepalive-interval: 0
|
||||
# api-key-entries:
|
||||
# - api-key: "sk-or-v1-...b780"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# # You may repeat the same alias to build an internal model pool.
|
||||
# # The client still sees only one alias in the model list.
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
# # and if the chosen upstream fails before producing output, the request will
|
||||
# # continue with the next upstream model in the same alias pool.
|
||||
# - name: "qwen3.5-plus"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "glm-5"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "claude-opus-4.66"
|
||||
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
|
||||
115
docs/gitlab-duo.md
Normal file
115
docs/gitlab-duo.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# GitLab Duo guide
|
||||
|
||||
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
|
||||
|
||||
It supports:
|
||||
|
||||
- OAuth login
|
||||
- personal access token login
|
||||
- automatic refresh of GitLab `direct_access` metadata
|
||||
- dynamic model discovery from GitLab metadata
|
||||
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
|
||||
- Claude-compatible and OpenAI-compatible downstream APIs
|
||||
|
||||
## What this means
|
||||
|
||||
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
|
||||
|
||||
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
|
||||
|
||||
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
|
||||
|
||||
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
|
||||
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
|
||||
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
|
||||
|
||||
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
|
||||
|
||||
- a stable alias: `gitlab-duo`
|
||||
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
|
||||
|
||||
## Login
|
||||
|
||||
OAuth login:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-login
|
||||
```
|
||||
|
||||
PAT login:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-token-login
|
||||
```
|
||||
|
||||
You can also provide inputs through environment variables:
|
||||
|
||||
```bash
|
||||
export GITLAB_BASE_URL=https://gitlab.com
|
||||
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- OAuth requires a GitLab OAuth application.
|
||||
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
|
||||
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
|
||||
|
||||
## Using the models
|
||||
|
||||
After login, start CLIProxyAPI normally and point your client at the local proxy.
|
||||
|
||||
You can select:
|
||||
|
||||
- `gitlab-duo` to use the current Duo-managed model for that account
|
||||
- the discovered provider model name if you want to pin it explicitly
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/models
|
||||
```
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gitlab-duo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
|
||||
|
||||
## How model freshness works
|
||||
|
||||
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
|
||||
|
||||
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
|
||||
|
||||
This matches GitLab's current public contract better than hardcoding model names.
|
||||
|
||||
## Current scope
|
||||
|
||||
The GitLab Duo provider now has:
|
||||
|
||||
- OAuth and PAT auth flows
|
||||
- runtime refresh of Duo gateway credentials
|
||||
- native Anthropic gateway routing
|
||||
- native OpenAI/Codex gateway routing
|
||||
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
|
||||
|
||||
Still out of scope today:
|
||||
|
||||
- websocket or session-specific parity beyond the current HTTP APIs
|
||||
- GitLab-specific IDE features that are not exposed through the public gateway contract
|
||||
|
||||
## References
|
||||
|
||||
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||
115
docs/gitlab-duo_CN.md
Normal file
115
docs/gitlab-duo_CN.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# GitLab Duo 使用说明
|
||||
|
||||
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
|
||||
|
||||
当前支持:
|
||||
|
||||
- OAuth 登录
|
||||
- personal access token 登录
|
||||
- 自动刷新 GitLab `direct_access` 元数据
|
||||
- 根据 GitLab 返回的元数据动态发现模型
|
||||
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
|
||||
- Claude 兼容与 OpenAI 兼容下游 API
|
||||
|
||||
## 这意味着什么
|
||||
|
||||
如果 GitLab Duo 返回的是 Anthropic 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
|
||||
|
||||
如果 GitLab Duo 返回的是 OpenAI 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
|
||||
|
||||
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider:
|
||||
|
||||
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
|
||||
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
|
||||
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
|
||||
|
||||
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
|
||||
|
||||
- 一个稳定别名:`gitlab-duo`
|
||||
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5` 或 `gpt-5-codex`
|
||||
|
||||
## 登录
|
||||
|
||||
OAuth 登录:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-login
|
||||
```
|
||||
|
||||
PAT 登录:
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI -gitlab-token-login
|
||||
```
|
||||
|
||||
也可以通过环境变量提供输入:
|
||||
|
||||
```bash
|
||||
export GITLAB_BASE_URL=https://gitlab.com
|
||||
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- OAuth 方式需要一个 GitLab OAuth application。
|
||||
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上,`api` scope 是最稳妥的基线。
|
||||
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
|
||||
|
||||
## 如何使用模型
|
||||
|
||||
登录完成后,正常启动 CLIProxyAPI,并让客户端连接到本地代理。
|
||||
|
||||
你可以选择:
|
||||
|
||||
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
|
||||
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
|
||||
|
||||
示例:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/models
|
||||
```
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gitlab-duo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
如果该 GitLab 账号当前绑定的是 Anthropic 模型,Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型,OpenAI 兼容客户端可以通过 `/v1/chat/completions` 或 `/v1/responses` 使用它。
|
||||
|
||||
## 模型如何保持最新
|
||||
|
||||
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
|
||||
|
||||
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
|
||||
|
||||
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
|
||||
|
||||
## 当前覆盖范围
|
||||
|
||||
GitLab Duo Provider 目前已经具备:
|
||||
|
||||
- OAuth 和 PAT 登录流程
|
||||
- Duo gateway 凭据的运行时刷新
|
||||
- Anthropic gateway 原生路由
|
||||
- OpenAI/Codex gateway 原生路由
|
||||
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
|
||||
|
||||
当前仍未覆盖:
|
||||
|
||||
- websocket 或 session 级别的完全对齐
|
||||
- GitLab 公开 gateway 契约之外的 IDE 专有能力
|
||||
|
||||
## 参考资料
|
||||
|
||||
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||
@@ -52,11 +52,11 @@ func init() {
|
||||
sdktr.Register(fOpenAI, fMyProv,
|
||||
func(model string, raw []byte, stream bool) []byte { return raw },
|
||||
sdktr.ResponseTransform{
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
||||
return []string{string(raw)}
|
||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
|
||||
return [][]byte{raw}
|
||||
},
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
||||
return string(raw)
|
||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
|
||||
return raw
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
278
gitlab-duo-codex-parity-plan.md
Normal file
278
gitlab-duo-codex-parity-plan.md
Normal file
@@ -0,0 +1,278 @@
|
||||
# Plan: GitLab Duo Codex Parity
|
||||
|
||||
**Generated**: 2026-03-10
|
||||
**Estimated Complexity**: High
|
||||
|
||||
## Overview
|
||||
Bring GitLab Duo support from the current "auth + basic executor" stage to the same practical level as `codex` inside `CLIProxyAPI`: a user logs in once, points external clients such as Claude Code at `CLIProxyAPI`, selects GitLab Duo-backed models, and gets stable streaming, multi-turn behavior, tool calling compatibility, and predictable model routing without manual provider-specific workarounds.
|
||||
|
||||
The core architectural shift is to stop treating GitLab Duo as only two REST wrappers (`/api/v4/chat/completions` and `/api/v4/code_suggestions/completions`) and instead use GitLab's `direct_access` contract as the primary runtime entrypoint wherever possible. Official GitLab docs confirm that `direct_access` returns AI gateway connection details, headers, token, and expiry; that contract is the closest path to codex-like provider behavior.
|
||||
|
||||
## Prerequisites
|
||||
- Official GitLab Duo API references confirmed during implementation:
|
||||
- `POST /api/v4/code_suggestions/direct_access`
|
||||
- `POST /api/v4/code_suggestions/completions`
|
||||
- `POST /api/v4/chat/completions`
|
||||
- Access to at least one real GitLab Duo account for manual verification.
|
||||
- One downstream client target for acceptance testing:
|
||||
- Claude Code against Claude-compatible endpoint
|
||||
- OpenAI-compatible client against `/v1/chat/completions` and `/v1/responses`
|
||||
- Existing PR branch as starting point:
|
||||
- `feat/gitlab-duo-auth`
|
||||
- PR [#2028](https://github.com/router-for-me/CLIProxyAPI/pull/2028)
|
||||
|
||||
## Definition Of Done
|
||||
- GitLab Duo models can be used via `CLIProxyAPI` from the same client surfaces that already work for `codex`.
|
||||
- Upstream streaming is real passthrough or faithful chunked forwarding, not synthetic whole-response replay.
|
||||
- Tool/function calling survives translation layers without dropping fields or corrupting names.
|
||||
- Multi-turn and session semantics are stable across `chat/completions`, `responses`, and Claude-compatible routes.
|
||||
- Model exposure stays current from GitLab metadata or gateway discovery without hardcoded stale model tables.
|
||||
- `go test ./...` stays green and at least one real manual end-to-end client flow is documented.
|
||||
|
||||
## Sprint 1: Contract And Gap Closure
|
||||
**Goal**: Replace assumptions with a hard compatibility contract between current `codex` behavior and what GitLab Duo can actually support.
|
||||
|
||||
**Demo/Validation**:
|
||||
- Written matrix showing `codex` features vs current GitLab Duo behavior.
|
||||
- One checked-in developer note or test fixture for real GitLab Duo payload examples.
|
||||
|
||||
### Task 1.1: Freeze Codex Parity Checklist
|
||||
- **Location**: [internal/runtime/executor/codex_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_executor.go), [internal/runtime/executor/codex_websockets_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_websockets_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go)
|
||||
- **Description**: Produce a concrete feature matrix for `codex`: HTTP execute, SSE execute, `/v1/responses`, websocket downstream path, tool calling, request IDs, session close semantics, and model registration behavior.
|
||||
- **Dependencies**: None
|
||||
- **Acceptance Criteria**:
|
||||
- A checklist exists in repo docs or issue notes.
|
||||
- Each capability is marked `required`, `optional`, or `not possible` for GitLab Duo.
|
||||
- **Validation**:
|
||||
- Review against current `codex` code paths.
|
||||
|
||||
### Task 1.2: Lock GitLab Duo Runtime Contract
|
||||
- **Location**: [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Validate the exact upstream contract we can rely on:
|
||||
- `direct_access` fields and refresh cadence
|
||||
- whether AI gateway path is usable directly
|
||||
- when `chat/completions` is available vs when fallback is required
|
||||
- what streaming shape is returned by `code_suggestions/completions?stream=true`
|
||||
- **Dependencies**: Task 1.1
|
||||
- **Acceptance Criteria**:
|
||||
- GitLab transport decision is explicit: `gateway-first`, `REST-first`, or `hybrid`.
|
||||
- Unknown areas are isolated behind feature flags, not spread across executor logic.
|
||||
- **Validation**:
|
||||
- Official docs + captured real responses from a Duo account.
|
||||
|
||||
### Task 1.3: Define Client-Facing Compatibility Targets
|
||||
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md), [gitlab-duo-codex-parity-plan.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/gitlab-duo-codex-parity-plan.md)
|
||||
- **Description**: Define exactly which external flows must work to call GitLab Duo support "like codex".
|
||||
- **Dependencies**: Task 1.2
|
||||
- **Acceptance Criteria**:
|
||||
- Required surfaces are listed:
|
||||
- Claude-compatible route
|
||||
- OpenAI `chat/completions`
|
||||
- OpenAI `responses`
|
||||
- optional downstream websocket path
|
||||
- Non-goals are explicit if GitLab upstream cannot support them.
|
||||
- **Validation**:
|
||||
- Maintainer review of stated scope.
|
||||
|
||||
## Sprint 2: Primary Transport Parity
|
||||
**Goal**: Move GitLab Duo execution onto a transport that supports codex-like runtime behavior.
|
||||
|
||||
**Demo/Validation**:
|
||||
- A GitLab Duo model works over real streaming through `/v1/chat/completions`.
|
||||
- No synthetic "collect full body then fake stream" path remains on the primary flow.
|
||||
|
||||
### Task 2.1: Refactor GitLab Executor Into Strategy Layers
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Split current executor into explicit strategies:
|
||||
- auth refresh/direct access refresh
|
||||
- gateway transport
|
||||
- GitLab REST fallback transport
|
||||
- downstream translation helpers
|
||||
- **Dependencies**: Sprint 1
|
||||
- **Acceptance Criteria**:
|
||||
- Executor no longer mixes discovery, refresh, fallback selection, and response synthesis in one path.
|
||||
- Transport choice is testable in isolation.
|
||||
- **Validation**:
|
||||
- Unit tests for strategy selection and fallback boundaries.
|
||||
|
||||
### Task 2.2: Implement Real Streaming Path
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go)
|
||||
- **Description**: Replace synthetic streaming with true upstream incremental forwarding:
|
||||
- use gateway stream if available
|
||||
- otherwise consume GitLab Code Suggestions streaming response and map chunks incrementally
|
||||
- **Dependencies**: Task 2.1
|
||||
- **Acceptance Criteria**:
|
||||
- `ExecuteStream` emits chunks before upstream completion.
|
||||
- error handling preserves status and early failure semantics.
|
||||
- **Validation**:
|
||||
- tests with chunked upstream server
|
||||
- manual curl check against `/v1/chat/completions` with `stream=true`
|
||||
|
||||
### Task 2.3: Preserve Upstream Auth And Headers Correctly
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go)
|
||||
- **Description**: Use `direct_access` connection details as first-class transport state:
|
||||
- gateway token
|
||||
- expiry
|
||||
- mandatory forwarded headers
|
||||
- model metadata
|
||||
- **Dependencies**: Task 2.1
|
||||
- **Acceptance Criteria**:
|
||||
- executor stops ignoring gateway headers/token when transport requires them
|
||||
- refresh logic never over-fetches `direct_access`
|
||||
- **Validation**:
|
||||
- tests verifying propagated headers and refresh interval behavior
|
||||
|
||||
## Sprint 3: Request/Response Semantics Parity
|
||||
**Goal**: Make GitLab Duo behave correctly under the same request shapes that current `codex` consumers send.
|
||||
|
||||
**Demo/Validation**:
|
||||
- OpenAI and Claude-compatible clients can do non-streaming and streaming conversations without losing structure.
|
||||
|
||||
### Task 3.1: Normalize Multi-Turn Message Mapping
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/translator](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/translator)
|
||||
- **Description**: Replace the current "flatten prompt into one instruction" behavior with stable multi-turn mapping:
|
||||
- preserve system context
|
||||
- preserve user/assistant ordering
|
||||
- maintain bounded context truncation
|
||||
- **Dependencies**: Sprint 2
|
||||
- **Acceptance Criteria**:
|
||||
- multi-turn requests are not collapsed into a lossy single string unless fallback mode explicitly requires it
|
||||
- truncation policy is deterministic and tested
|
||||
- **Validation**:
|
||||
- golden tests for request mapping
|
||||
|
||||
### Task 3.2: Tool Calling Compatibility Layer
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go)
|
||||
- **Description**: Decide and implement one of two paths:
|
||||
- native pass-through if GitLab gateway supports tool/function structures
|
||||
- strict downgrade path with explicit unsupported errors instead of silent field loss
|
||||
- **Dependencies**: Task 3.1
|
||||
- **Acceptance Criteria**:
|
||||
- tool-related fields are either preserved correctly or rejected explicitly
|
||||
- no silent corruption of tool names, tool calls, or tool results
|
||||
- **Validation**:
|
||||
- table-driven tests for tool payloads
|
||||
- one manual client scenario using tools
|
||||
|
||||
### Task 3.3: Token Counting And Usage Reporting Fidelity
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/usage_helpers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/usage_helpers.go)
|
||||
- **Description**: Improve token/usage reporting so GitLab models behave like first-class providers in logs and scheduling.
|
||||
- **Dependencies**: Sprint 2
|
||||
- **Acceptance Criteria**:
|
||||
- `CountTokens` uses the closest supported estimation path
|
||||
- usage logging distinguishes prompt vs completion when possible
|
||||
- **Validation**:
|
||||
- unit tests for token estimation outputs
|
||||
|
||||
## Sprint 4: Responses And Session Parity
|
||||
**Goal**: Reach codex-level support for OpenAI Responses clients and long-lived sessions where GitLab upstream permits it.
|
||||
|
||||
**Demo/Validation**:
|
||||
- `/v1/responses` works with GitLab Duo in a realistic client flow.
|
||||
- If websocket parity is not possible, the code explicitly declines it and keeps HTTP paths stable.
|
||||
|
||||
### Task 4.1: Make GitLab Compatible With `/v1/responses`
|
||||
- **Location**: [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Ensure GitLab transport can safely back the Responses API path, including compact responses if applicable.
|
||||
- **Dependencies**: Sprint 3
|
||||
- **Acceptance Criteria**:
|
||||
- GitLab Duo can be selected behind `/v1/responses`
|
||||
- response IDs and follow-up semantics are defined
|
||||
- **Validation**:
|
||||
- handler tests analogous to codex/openai responses tests
|
||||
|
||||
### Task 4.2: Evaluate Downstream Websocket Parity
|
||||
- **Location**: [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||
- **Description**: Decide whether GitLab Duo can support downstream websocket sessions like codex:
|
||||
- if yes, add session-aware execution path
|
||||
- if no, mark GitLab auth as websocket-ineligible and keep HTTP routes first-class
|
||||
- **Dependencies**: Task 4.1
|
||||
- **Acceptance Criteria**:
|
||||
- websocket behavior is explicit, not accidental
|
||||
- no route claims websocket support when the upstream cannot honor it
|
||||
- **Validation**:
|
||||
- websocket handler tests or explicit capability tests
|
||||
|
||||
### Task 4.3: Add Session Cleanup And Failure Recovery Semantics
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/cliproxy/auth/conductor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/auth/conductor.go)
|
||||
- **Description**: Add codex-like session cleanup, retry boundaries, and model suspension/resume behavior for GitLab failures and quota events.
|
||||
- **Dependencies**: Sprint 2
|
||||
- **Acceptance Criteria**:
|
||||
- auth/model cooldown behavior is predictable on GitLab 4xx/5xx/quota responses
|
||||
- executor cleans up per-session resources if any are introduced
|
||||
- **Validation**:
|
||||
- tests for quota and retry behavior
|
||||
|
||||
## Sprint 5: Client UX, Model UX, And Manual E2E
|
||||
**Goal**: Make GitLab Duo feel like a normal built-in provider to operators and downstream clients.
|
||||
|
||||
**Demo/Validation**:
|
||||
- A documented setup exists for "login once, point Claude Code at CLIProxyAPI, use GitLab Duo-backed model".
|
||||
|
||||
### Task 5.1: Model Alias And Provider UX Cleanup
|
||||
- **Location**: [sdk/cliproxy/service.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/service.go), [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
|
||||
- **Description**: Normalize what users see:
|
||||
- stable alias such as `gitlab-duo`
|
||||
- discovered upstream model names
|
||||
- optional prefix behavior
|
||||
- account labels that clearly distinguish OAuth vs PAT
|
||||
- **Dependencies**: Sprint 3
|
||||
- **Acceptance Criteria**:
|
||||
- users can select a stable GitLab alias even when upstream model changes
|
||||
- dynamic model discovery does not cause confusing model churn
|
||||
- **Validation**:
|
||||
- registry tests and manual `/v1/models` inspection
|
||||
|
||||
### Task 5.2: Add Real End-To-End Acceptance Tests
|
||||
- **Location**: [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go), [sdk/api/handlers/openai](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai)
|
||||
- **Description**: Add higher-level tests covering the actual proxy surfaces:
|
||||
- OpenAI `chat/completions`
|
||||
- OpenAI `responses`
|
||||
- Claude-compatible request path if GitLab is routed there
|
||||
- **Dependencies**: Sprint 4
|
||||
- **Acceptance Criteria**:
|
||||
- tests fail if streaming regresses into synthetic buffering again
|
||||
- tests cover at least one tool-related request and one multi-turn request
|
||||
- **Validation**:
|
||||
- `go test ./...`
|
||||
|
||||
### Task 5.3: Publish Operator Documentation
|
||||
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
|
||||
- **Description**: Document:
|
||||
- OAuth setup requirements
|
||||
- PAT requirements
|
||||
- current capability matrix
|
||||
- known limitations if websocket/tool parity is partial
|
||||
- **Dependencies**: Sprint 5.1
|
||||
- **Acceptance Criteria**:
|
||||
- setup instructions are enough for a new user to reproduce the GitLab Duo flow
|
||||
- limitations are explicit
|
||||
- **Validation**:
|
||||
- dry-run docs review from a clean environment
|
||||
|
||||
## Testing Strategy
|
||||
- Keep `go test ./...` green after every committable task.
|
||||
- Add table-driven tests first for request mapping, refresh behavior, and dynamic model registration.
|
||||
- Add transport tests with `httptest.Server` for:
|
||||
- real chunked streaming
|
||||
- header propagation from `direct_access`
|
||||
- upstream fallback rules
|
||||
- Add at least one manual acceptance checklist:
|
||||
- login via OAuth
|
||||
- login via PAT
|
||||
- list models
|
||||
- run one streaming prompt via OpenAI route
|
||||
- run one prompt from the target downstream client
|
||||
|
||||
## Potential Risks & Gotchas
|
||||
- GitLab public docs expose `direct_access`, but do not fully document every possible AI gateway path. We should isolate any empirically discovered gateway assumptions behind one transport layer and feature flags.
|
||||
- `chat/completions` availability differs by GitLab offering and version. The executor must not assume it always exists.
|
||||
- Code Suggestions is completion-oriented; lossy mapping from rich chat/tool payloads will make GitLab Duo feel worse than codex unless explicitly handled.
|
||||
- Synthetic streaming is not good enough for codex parity and will cause regressions in interactive clients.
|
||||
- Dynamic model discovery can create unstable UX if the stable alias and discovered model IDs are not separated cleanly.
|
||||
- PAT auth may validate successfully while still lacking effective Duo permissions. Error reporting must surface this explicitly.
|
||||
|
||||
## Rollback Plan
|
||||
- Keep the current basic GitLab executor behind a fallback mode until the new transport path is stable.
|
||||
- If parity work destabilizes existing providers, revert only GitLab-specific executor changes and leave auth support intact.
|
||||
- Preserve the stable `gitlab-duo` alias so rollback does not break client configuration.
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -14,13 +13,12 @@ import (
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const defaultAPICallTimeout = 60 * time.Second
|
||||
@@ -725,47 +723,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
}
|
||||
|
||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
proxyStr = strings.TrimSpace(proxyStr)
|
||||
if proxyStr == "" {
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||
if errBuild != nil {
|
||||
log.WithError(errBuild).Debug("build proxy transport failed")
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyURL, errParse := url.Parse(proxyStr)
|
||||
if errParse != nil {
|
||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
||||
return nil
|
||||
}
|
||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
||||
log.Debug("proxy URL missing scheme/host")
|
||||
return nil
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
||||
return nil
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
return nil
|
||||
return transport
|
||||
}
|
||||
|
||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||
|
||||
@@ -2,172 +2,112 @@ package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, a := range s.items {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
_ = ctx
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth.Clone()
|
||||
s.mu.Unlock()
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
delete(s.items, id)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("unexpected content-type: %s", ct)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
values, err := url.ParseQuery(string(bodyBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("parse form: %v", err)
|
||||
}
|
||||
if values.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
||||
}
|
||||
if values.Get("refresh_token") != "rt" {
|
||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
||||
}
|
||||
if values.Get("client_id") != antigravityOAuthClientID {
|
||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
||||
}
|
||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
||||
t.Fatalf("unexpected client_secret")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-token",
|
||||
"refresh_token": "rt2",
|
||||
"expires_in": int64(3600),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-test.json",
|
||||
FileName: "antigravity-test.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "rt",
|
||||
"expires_in": int64(3600),
|
||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||
},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
|
||||
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
|
||||
httpTransport, ok := transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||
}
|
||||
if httpTransport.Proxy != nil {
|
||||
t.Fatal("expected direct transport to disable proxy function")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := &Handler{
|
||||
cfg: &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||
},
|
||||
}
|
||||
|
||||
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
|
||||
httpTransport, ok := transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||
}
|
||||
|
||||
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
if errRequest != nil {
|
||||
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||
}
|
||||
|
||||
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||
if errProxy != nil {
|
||||
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||
}
|
||||
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
|
||||
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
geminiAuth := &coreauth.Auth{
|
||||
ID: "gemini:apikey:123",
|
||||
Provider: "gemini",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
},
|
||||
}
|
||||
compatAuth := &coreauth.Auth{
|
||||
ID: "openai-compatibility:bohe:456",
|
||||
Provider: "bohe",
|
||||
Label: "bohe",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "shared-key",
|
||||
"compat_name": "bohe",
|
||||
"provider_key": "bohe",
|
||||
},
|
||||
}
|
||||
|
||||
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
|
||||
t.Fatalf("register gemini auth: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
|
||||
t.Fatalf("register compat auth: %v", errRegister)
|
||||
}
|
||||
|
||||
geminiIndex := geminiAuth.EnsureIndex()
|
||||
compatIndex := compatAuth.EnsureIndex()
|
||||
if geminiIndex == compatIndex {
|
||||
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
|
||||
}
|
||||
|
||||
h := &Handler{authManager: manager}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
|
||||
gotGemini := h.authByIndex(geminiIndex)
|
||||
if gotGemini == nil {
|
||||
t.Fatal("expected gemini auth by index")
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
||||
if gotGemini.ID != geminiAuth.ID {
|
||||
t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth in manager after update")
|
||||
gotCompat := h.authByIndex(compatIndex)
|
||||
if gotCompat == nil {
|
||||
t.Fatal("expected compat auth by index")
|
||||
}
|
||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-valid.json",
|
||||
FileName: "antigravity-valid.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "ok-token",
|
||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
h := &Handler{}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "ok-token" {
|
||||
t.Fatalf("expected existing token, got %q", token)
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
||||
if gotCompat.ID != compatAuth.ID {
|
||||
t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
@@ -54,6 +55,8 @@ const (
|
||||
codexCallbackPort = 1455
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
gitLabLoginModeOAuth = "oauth"
|
||||
gitLabLoginModePAT = "pat"
|
||||
)
|
||||
|
||||
type callbackForwarder struct {
|
||||
@@ -338,6 +341,21 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
emailValue := gjson.GetBytes(data, "email").String()
|
||||
fileData["type"] = typeValue
|
||||
fileData["email"] = emailValue
|
||||
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
|
||||
switch pv.Type {
|
||||
case gjson.Number:
|
||||
fileData["priority"] = int(pv.Int())
|
||||
case gjson.String:
|
||||
if parsed, errAtoi := strconv.Atoi(strings.TrimSpace(pv.String())); errAtoi == nil {
|
||||
fileData["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if nv := gjson.GetBytes(data, "note"); nv.Exists() && nv.Type == gjson.String {
|
||||
if trimmed := strings.TrimSpace(nv.String()); trimmed != "" {
|
||||
fileData["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, fileData)
|
||||
@@ -421,6 +439,37 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||
entry["id_token"] = claims
|
||||
}
|
||||
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
} else if auth.Metadata != nil {
|
||||
if rawPriority, ok := auth.Metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
entry["priority"] = int(v)
|
||||
case int:
|
||||
entry["priority"] = v
|
||||
case string:
|
||||
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
entry["priority"] = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Expose note from Attributes (set by synthesizer from JSON "note" field).
|
||||
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
|
||||
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
|
||||
entry["note"] = note
|
||||
} else if auth.Metadata != nil {
|
||||
if rawNote, ok := auth.Metadata["note"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
|
||||
entry["note"] = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
@@ -845,7 +894,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
@@ -857,6 +906,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
Note *string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
@@ -899,14 +949,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if req.Priority != nil || req.Note != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
if targetAuth.Attributes == nil {
|
||||
targetAuth.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
delete(targetAuth.Attributes, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
|
||||
}
|
||||
}
|
||||
if req.Note != nil {
|
||||
trimmedNote := strings.TrimSpace(*req.Note)
|
||||
if trimmedNote == "" {
|
||||
delete(targetAuth.Metadata, "note")
|
||||
delete(targetAuth.Attributes, "note")
|
||||
} else {
|
||||
targetAuth.Metadata["note"] = trimmedNote
|
||||
targetAuth.Attributes["note"] = trimmedNote
|
||||
}
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
@@ -999,6 +1067,165 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func gitLabBaseURLFromRequest(c *gin.Context) string {
|
||||
if c != nil {
|
||||
if raw := strings.TrimSpace(c.Query("base_url")); raw != "" {
|
||||
return gitlabauth.NormalizeBaseURL(raw)
|
||||
}
|
||||
}
|
||||
if raw := strings.TrimSpace(os.Getenv("GITLAB_BASE_URL")); raw != "" {
|
||||
return gitlabauth.NormalizeBaseURL(raw)
|
||||
}
|
||||
return gitlabauth.DefaultBaseURL
|
||||
}
|
||||
|
||||
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
|
||||
metadata := map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": strings.TrimSpace(mode),
|
||||
"base_url": gitlabauth.NormalizeBaseURL(baseURL),
|
||||
"last_refresh": time.Now().UTC().Format(time.RFC3339),
|
||||
"refresh_interval_seconds": 240,
|
||||
}
|
||||
if tokenResp != nil {
|
||||
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
|
||||
metadata["refresh_token"] = refreshToken
|
||||
}
|
||||
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
|
||||
metadata["token_type"] = tokenType
|
||||
}
|
||||
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
|
||||
metadata["scope"] = scope
|
||||
}
|
||||
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
|
||||
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
mergeGitLabDirectAccessMetadata(metadata, direct)
|
||||
return metadata
|
||||
}
|
||||
|
||||
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
|
||||
if metadata == nil || direct == nil {
|
||||
return
|
||||
}
|
||||
if base := strings.TrimSpace(direct.BaseURL); base != "" {
|
||||
metadata["duo_gateway_base_url"] = base
|
||||
}
|
||||
if token := strings.TrimSpace(direct.Token); token != "" {
|
||||
metadata["duo_gateway_token"] = token
|
||||
}
|
||||
if direct.ExpiresAt > 0 {
|
||||
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
|
||||
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
|
||||
now := time.Now().UTC()
|
||||
if ttl := expiry.Sub(now); ttl > 0 {
|
||||
interval := int(ttl.Seconds()) / 2
|
||||
switch {
|
||||
case interval < 60:
|
||||
interval = 60
|
||||
case interval > 240:
|
||||
interval = 240
|
||||
}
|
||||
metadata["refresh_interval_seconds"] = interval
|
||||
}
|
||||
}
|
||||
if len(direct.Headers) > 0 {
|
||||
headers := make(map[string]string, len(direct.Headers))
|
||||
for key, value := range direct.Headers {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
headers[key] = value
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
metadata["duo_gateway_headers"] = headers
|
||||
}
|
||||
}
|
||||
if direct.ModelDetails != nil {
|
||||
modelDetails := map[string]any{}
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
modelDetails["model_provider"] = provider
|
||||
metadata["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
modelDetails["model_name"] = model
|
||||
metadata["model_name"] = model
|
||||
}
|
||||
if len(modelDetails) > 0 {
|
||||
metadata["model_details"] = modelDetails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func primaryGitLabEmail(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
if value := strings.TrimSpace(user.Email); value != "" {
|
||||
return value
|
||||
}
|
||||
return strings.TrimSpace(user.PublicEmail)
|
||||
}
|
||||
|
||||
func gitLabAccountIdentifier(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return "user"
|
||||
}
|
||||
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
|
||||
func sanitizeGitLabFileName(value string) string {
|
||||
value = strings.TrimSpace(strings.ToLower(value))
|
||||
if value == "" {
|
||||
return "user"
|
||||
}
|
||||
var builder strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r >= '0' && r <= '9':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
default:
|
||||
if !lastDash {
|
||||
builder.WriteRune('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
}
|
||||
result := strings.Trim(builder.String(), "-")
|
||||
if result == "" {
|
||||
return "user"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskGitLabToken(token string) string {
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if len(trimmed) <= 8 {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
@@ -1312,12 +1539,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1326,7 +1553,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
@@ -1337,19 +1564,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1362,13 +1589,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1549,6 +1776,263 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestGitLabToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing GitLab Duo authentication...")
|
||||
|
||||
baseURL := gitLabBaseURLFromRequest(c)
|
||||
clientID := strings.TrimSpace(c.Query("client_id"))
|
||||
clientSecret := strings.TrimSpace(c.Query("client_secret"))
|
||||
if clientID == "" {
|
||||
clientID = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_ID"))
|
||||
}
|
||||
if clientSecret == "" {
|
||||
clientSecret = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_SECRET"))
|
||||
}
|
||||
if clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "gitlab client_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
pkceCodes, err := gitlabauth.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate GitLab PKCE codes: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||
return
|
||||
}
|
||||
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate GitLab state parameter: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := gitlabauth.RedirectURL(gitlabauth.DefaultCallbackPort)
|
||||
authClient := gitlabauth.NewAuthClient(h.cfg)
|
||||
authURL, err := authClient.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate GitLab authorization URL: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "gitlab")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/gitlab/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute gitlab callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(gitlabauth.DefaultCallbackPort, "gitlab", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start gitlab callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarderInstance(gitlabauth.DefaultCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gitlab-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var code string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "gitlab") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("gitlab oauth flow timed out")
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errRead := os.ReadFile(waitFile); errRead == nil {
|
||||
var payload map[string]string
|
||||
_ = json.Unmarshal(data, &payload)
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
SetOAuthSessionError(state, errStr)
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != state {
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
code = strings.TrimSpace(payload["code"])
|
||||
if code == "" {
|
||||
SetOAuthSessionError(state, "Authorization code missing")
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
tokenResp, errExchange := authClient.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, code, pkceCodes.CodeVerifier)
|
||||
if errExchange != nil {
|
||||
log.Errorf("Failed to exchange GitLab authorization code: %v", errExchange)
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
|
||||
user, errUser := authClient.GetCurrentUser(ctx, baseURL, tokenResp.AccessToken)
|
||||
if errUser != nil {
|
||||
log.Errorf("Failed to fetch GitLab user profile: %v", errUser)
|
||||
SetOAuthSessionError(state, "Failed to fetch account profile")
|
||||
return
|
||||
}
|
||||
|
||||
direct, errDirect := authClient.FetchDirectAccess(ctx, baseURL, tokenResp.AccessToken)
|
||||
if errDirect != nil {
|
||||
log.Errorf("Failed to fetch GitLab direct access metadata: %v", errDirect)
|
||||
SetOAuthSessionError(state, "Failed to fetch GitLab Duo access")
|
||||
return
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata["oauth_client_id"] = clientID
|
||||
if clientSecret != "" {
|
||||
metadata["oauth_client_secret"] = clientSecret
|
||||
}
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "gitlab",
|
||||
FileName: fileName,
|
||||
Label: identifier,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save GitLab auth record: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("GitLab Duo authentication successful. Token saved to %s\n", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("gitlab")
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestGitLabPATToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
var payload struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
PersonalAccessToken string `json:"personal_access_token"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := gitlabauth.NormalizeBaseURL(strings.TrimSpace(payload.BaseURL))
|
||||
if baseURL == "" {
|
||||
baseURL = gitLabBaseURLFromRequest(nil)
|
||||
}
|
||||
pat := strings.TrimSpace(payload.PersonalAccessToken)
|
||||
if pat == "" {
|
||||
pat = strings.TrimSpace(payload.Token)
|
||||
}
|
||||
if pat == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "personal_access_token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
authClient := gitlabauth.NewAuthClient(h.cfg)
|
||||
|
||||
user, err := authClient.GetCurrentUser(ctx, baseURL, pat)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
|
||||
return
|
||||
}
|
||||
patSelf, err := authClient.GetPersonalAccessTokenSelf(ctx, baseURL, pat)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
|
||||
return
|
||||
}
|
||||
direct, err := authClient.FetchDirectAccess(ctx, baseURL, pat)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
|
||||
metadata["auth_kind"] = "personal_access_token"
|
||||
metadata["personal_access_token"] = pat
|
||||
metadata["token_preview"] = maskGitLabToken(pat)
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := primaryGitLabEmail(user); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
if patSelf != nil {
|
||||
if name := strings.TrimSpace(patSelf.Name); name != "" {
|
||||
metadata["pat_name"] = name
|
||||
}
|
||||
if len(patSelf.Scopes) > 0 {
|
||||
metadata["pat_scopes"] = append([]string(nil), patSelf.Scopes...)
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "gitlab",
|
||||
FileName: fileName,
|
||||
Label: identifier + " (PAT)",
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
savedPath, err := h.saveTokenRecord(ctx, record)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": "ok",
|
||||
"saved_path": savedPath,
|
||||
"username": strings.TrimSpace(user.Username),
|
||||
"email": primaryGitLabEmail(user),
|
||||
"token_label": identifier,
|
||||
}
|
||||
if direct != nil && direct.ModelDetails != nil {
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
response["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
response["model_name"] = model
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("GitLab Duo PAT authentication successful. Token saved to %s\n", savedPath)
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
@@ -2019,17 +2503,20 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
if label == "" {
|
||||
label = username
|
||||
}
|
||||
metadata, errMeta := copilotTokenMetadata(tokenStorage)
|
||||
if errMeta != nil {
|
||||
log.Errorf("Failed to build token metadata: %v", errMeta)
|
||||
SetOAuthSessionError(state, "Failed to build token metadata")
|
||||
return
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "github-copilot",
|
||||
Label: label,
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": userInfo.Email,
|
||||
"username": username,
|
||||
"name": userInfo.Name,
|
||||
},
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
@@ -2054,6 +2541,21 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func copilotTokenMetadata(storage *copilot.CopilotTokenStorage) (map[string]any, error) {
|
||||
if storage == nil {
|
||||
return nil, fmt.Errorf("token storage is nil")
|
||||
}
|
||||
payload, errMarshal := json.Marshal(storage)
|
||||
if errMarshal != nil {
|
||||
return nil, fmt.Errorf("marshal token storage: %w", errMarshal)
|
||||
}
|
||||
metadata := make(map[string]any)
|
||||
if errUnmarshal := json.Unmarshal(payload, &metadata); errUnmarshal != nil {
|
||||
return nil, fmt.Errorf("unmarshal token storage: %w", errUnmarshal)
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
|
||||
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/api/v4/user":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 42,
|
||||
"username": "gitlab-user",
|
||||
"name": "GitLab User",
|
||||
"email": "gitlab@example.com",
|
||||
})
|
||||
case "/api/v4/personal_access_tokens/self":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 7,
|
||||
"name": "management-center",
|
||||
"scopes": []string{"api", "read_user"},
|
||||
"user_id": 42,
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1893456000,
|
||||
"headers": map[string]string{
|
||||
"X-Gitlab-Realm": "saas",
|
||||
},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
|
||||
h.tokenStore = store
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.RequestGitLabPATToken(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if got := resp["status"]; got != "ok" {
|
||||
t.Fatalf("status = %#v, want ok", got)
|
||||
}
|
||||
if got := resp["model_provider"]; got != "anthropic" {
|
||||
t.Fatalf("model_provider = %#v, want anthropic", got)
|
||||
}
|
||||
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
|
||||
}
|
||||
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if len(store.items) != 1 {
|
||||
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
|
||||
}
|
||||
var saved *coreauth.Auth
|
||||
for _, item := range store.items {
|
||||
saved = item
|
||||
}
|
||||
if saved == nil {
|
||||
t.Fatal("expected saved auth record")
|
||||
}
|
||||
if saved.Provider != "gitlab" {
|
||||
t.Fatalf("provider = %q, want gitlab", saved.Provider)
|
||||
}
|
||||
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
|
||||
}
|
||||
if got := saved.Metadata["model_provider"]; got != "anthropic" {
|
||||
t.Fatalf("saved model_provider = %#v, want anthropic", got)
|
||||
}
|
||||
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
|
||||
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
state := "gitlab-state-123"
|
||||
RegisterOAuthSession(state, "gitlab")
|
||||
t.Cleanup(func() { CompleteOAuthSession(state) })
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.PostOAuthCallback(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read callback file: %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
t.Fatalf("decode callback payload: %v", err)
|
||||
}
|
||||
if got := payload["code"]; got != "test-code" {
|
||||
t.Fatalf("callback code = %q, want test-code", got)
|
||||
}
|
||||
if got := payload["state"]; got != state {
|
||||
t.Fatalf("callback state = %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
|
||||
provider, err := NormalizeOAuthProvider("gitlab")
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
|
||||
}
|
||||
if provider != "gitlab" {
|
||||
t.Fatalf("provider = %q, want gitlab", provider)
|
||||
}
|
||||
}
|
||||
@@ -509,8 +509,12 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
||||
}
|
||||
for i := range arr {
|
||||
normalizeVertexCompatKey(&arr[i])
|
||||
if arr[i].APIKey == "" {
|
||||
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
|
||||
return
|
||||
}
|
||||
}
|
||||
h.cfg.VertexCompatAPIKey = arr
|
||||
h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
@@ -228,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gitlab":
|
||||
return "gitlab", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
|
||||
49
internal/api/handlers/management/test_store_test.go
Normal file
49
internal/api/handlers/management/test_store_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, item := range s.items {
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.items, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) SetBaseDir(string) {}
|
||||
@@ -403,6 +403,20 @@ func (s *Server) setupRoutes() {
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/google/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
@@ -658,6 +672,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
|
||||
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
|
||||
@@ -4,12 +4,12 @@ package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/proxy"
|
||||
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
|
||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||
var dialer proxy.Dialer = proxy.Direct
|
||||
if cfg != nil && cfg.ProxyURL != "" {
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
||||
} else {
|
||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
||||
} else {
|
||||
dialer = pDialer
|
||||
}
|
||||
if cfg != nil {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
@@ -20,9 +18,9 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
var transport *http.Transport
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
// Handle SOCKS5 proxy.
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
auth := &proxy.Auth{User: username, Password: password}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||
}
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
// Handle HTTP/HTTPS proxy.
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
if transport != nil {
|
||||
proxyClient := &http.Client{Transport: transport}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||
}
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
} else if transport != nil {
|
||||
proxyClient := &http.Client{Transport: transport}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Configure the OAuth2 client.
|
||||
conf := &oauth2.Config{
|
||||
ClientID: ClientID,
|
||||
@@ -327,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
var manualInputCh <-chan string
|
||||
var manualInputErrCh <-chan error
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
@@ -348,13 +329,14 @@ waitForCallback:
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := misc.ParseOAuthCallback(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
continue
|
||||
case input := <-manualInputCh:
|
||||
manualInputCh = nil
|
||||
manualInputErrCh = nil
|
||||
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
@@ -367,6 +349,8 @@ waitForCallback:
|
||||
}
|
||||
authCode = parsed.Code
|
||||
break waitForCallback
|
||||
case errManual := <-manualInputErrCh:
|
||||
return nil, errManual
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
|
||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBaseURL = "https://gitlab.com"
|
||||
DefaultCallbackPort = 17171
|
||||
defaultOAuthScope = "api read_user"
|
||||
)
|
||||
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type OAuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
type OAuthServer struct {
|
||||
server *http.Server
|
||||
port int
|
||||
resultChan chan *OAuthResult
|
||||
errorChan chan error
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
PublicEmail string `json:"public_email"`
|
||||
}
|
||||
|
||||
type PersonalAccessTokenSelf struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
type ModelDetails struct {
|
||||
ModelProvider string `json:"model_provider"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
type DirectAccessResponse struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
ModelDetails *ModelDetails `json:"model_details,omitempty"`
|
||||
}
|
||||
|
||||
type DiscoveredModel struct {
|
||||
ModelProvider string
|
||||
ModelName string
|
||||
}
|
||||
|
||||
type AuthClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewAuthClient(cfg *config.Config) *AuthClient {
|
||||
client := &http.Client{}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &AuthClient{httpClient: client}
|
||||
}
|
||||
|
||||
func NormalizeBaseURL(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return DefaultBaseURL
|
||||
}
|
||||
if !strings.Contains(value, "://") {
|
||||
value = "https://" + value
|
||||
}
|
||||
value = strings.TrimRight(value, "/")
|
||||
return value
|
||||
}
|
||||
|
||||
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
||||
if token == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
||||
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
||||
}
|
||||
if token.ExpiresIn > 0 {
|
||||
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
verifierBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
||||
}
|
||||
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return &PKCECodes{
|
||||
CodeVerifier: verifier,
|
||||
CodeChallenge: challenge,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
resultChan: make(chan *OAuthResult, 1),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("gitlab oauth server already running")
|
||||
}
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
s.running = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.errorChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
s.running = false
|
||||
s.server = nil
|
||||
}()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case result := <-s.resultChan:
|
||||
return result, nil
|
||||
case err := <-s.errorChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
query := r.URL.Query()
|
||||
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||
s.sendResult(&OAuthResult{Error: errParam})
|
||||
http.Error(w, errParam, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
if code == "" || state == "" {
|
||||
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
||||
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
||||
}
|
||||
|
||||
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||
select {
|
||||
case s.resultChan <- result:
|
||||
default:
|
||||
log.Debug("gitlab oauth result channel full, dropping callback result")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func RedirectURL(port int) string {
|
||||
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
|
||||
}
|
||||
|
||||
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
||||
if pkce == nil {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
||||
}
|
||||
if strings.TrimSpace(clientID) == "" {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
||||
}
|
||||
baseURL = NormalizeBaseURL(baseURL)
|
||||
params := url.Values{
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"scope": {defaultOAuthScope},
|
||||
"state": {strings.TrimSpace(state)},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
}
|
||||
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"code": {strings.TrimSpace(code)},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||
}
|
||||
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
||||
form.Set("client_id", clientID)
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var token TokenResponse
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var pat PersonalAccessTokenSelf
|
||||
if err := json.Unmarshal(body, &pat); err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
||||
}
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var direct DirectAccessResponse
|
||||
if err := json.Unmarshal(body, &direct); err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
||||
}
|
||||
if direct.Headers == nil {
|
||||
direct.Headers = make(map[string]string)
|
||||
}
|
||||
return &direct, nil
|
||||
}
|
||||
|
||||
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
models := make([]DiscoveredModel, 0, 4)
|
||||
seen := make(map[string]struct{})
|
||||
appendModel := func(provider, name string) {
|
||||
provider = strings.TrimSpace(provider)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
models = append(models, DiscoveredModel{
|
||||
ModelProvider: provider,
|
||||
ModelName: name,
|
||||
})
|
||||
}
|
||||
|
||||
if raw, ok := metadata["model_details"]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||
|
||||
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||
if raw, ok := metadata[key]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||
if nested, ok := typed["models"]; ok {
|
||||
appendDiscoveredModels(nested, appendModel)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range typed {
|
||||
appendDiscoveredModels(item, appendModel)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range typed {
|
||||
appendModel("", item)
|
||||
}
|
||||
case string:
|
||||
appendModel("", typed)
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(raw any) string {
|
||||
switch typed := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(typed.String())
|
||||
case json.Number:
|
||||
return typed.String()
|
||||
case int:
|
||||
return strconv.Itoa(typed)
|
||||
case int64:
|
||||
return strconv.FormatInt(typed, 10)
|
||||
case float64:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
138
internal/auth/gitlab/gitlab_test.go
Normal file
138
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
|
||||
client := NewAuthClient(nil)
|
||||
pkce, err := GeneratePKCECodes()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePKCECodes() error = %v", err)
|
||||
}
|
||||
|
||||
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL() error = %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse(authURL) error = %v", err)
|
||||
}
|
||||
if got := parsed.Path; got != "/oauth/authorize" {
|
||||
t.Fatalf("expected /oauth/authorize path, got %q", got)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("client_id"); got != "client-id" {
|
||||
t.Fatalf("expected client_id, got %q", got)
|
||||
}
|
||||
if got := query.Get("scope"); got != defaultOAuthScope {
|
||||
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
|
||||
}
|
||||
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||
t.Fatalf("expected PKCE method S256, got %q", got)
|
||||
}
|
||||
if got := query.Get("code_challenge"); got == "" {
|
||||
t.Fatal("expected non-empty code_challenge")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm() error = %v", err)
|
||||
}
|
||||
if got := r.Form.Get("grant_type"); got != "authorization_code" {
|
||||
t.Fatalf("expected authorization_code grant, got %q", got)
|
||||
}
|
||||
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
|
||||
t.Fatalf("expected code_verifier, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"token_type": "Bearer",
|
||||
"scope": "api read_user",
|
||||
"created_at": 1710000000,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewAuthClient(nil)
|
||||
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
|
||||
}
|
||||
if token.AccessToken != "oauth-access" {
|
||||
t.Fatalf("expected access token, got %q", token.AccessToken)
|
||||
}
|
||||
if token.RefreshToken != "oauth-refresh" {
|
||||
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDiscoveredModels(t *testing.T) {
|
||||
models := ExtractDiscoveredModels(map[string]any{
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
"supported_models": []any{
|
||||
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||
"claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("expected 2 unique models, got %d", len(models))
|
||||
}
|
||||
if models[0].ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("unexpected first model %q", models[0].ModelName)
|
||||
}
|
||||
if models[1].ModelName != "gpt-4.1" {
|
||||
t.Fatalf("unexpected second model %q", models[1].ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
|
||||
t.Fatalf("expected bearer token, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{
|
||||
"X-Gitlab-Realm": "saas",
|
||||
},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewAuthClient(nil)
|
||||
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchDirectAccess() error = %v", err)
|
||||
}
|
||||
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
69
internal/cmd/gitlab_login.go
Normal file
69
internal/cmd/gitlab_login.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "oauth",
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("GitLab Duo authentication successful!")
|
||||
}
|
||||
|
||||
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "pat",
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("GitLab Duo PAT authentication successful!")
|
||||
}
|
||||
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
configYAML := []byte(`
|
||||
codex-header-defaults:
|
||||
user-agent: " my-codex-client/1.0 "
|
||||
beta-features: " feature-a,feature-b "
|
||||
`)
|
||||
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfigOptional(configPath, false)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||
}
|
||||
|
||||
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
|
||||
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
|
||||
}
|
||||
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
|
||||
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
|
||||
}
|
||||
}
|
||||
@@ -101,6 +101,10 @@ type Config struct {
|
||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||
|
||||
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
|
||||
// These are used only when the client does not send its own headers.
|
||||
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
|
||||
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
@@ -150,6 +154,14 @@ type ClaudeHeaderDefaults struct {
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||
// model requests for OAuth/file-backed auth when the client omits them.
|
||||
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
|
||||
type CodexHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -673,12 +685,15 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
||||
// Sanitize Vertex-compatible API keys.
|
||||
cfg.SanitizeVertexCompatKeys()
|
||||
|
||||
// Sanitize Codex keys: drop entries without base-url
|
||||
cfg.SanitizeCodexKeys()
|
||||
|
||||
// Sanitize Codex header defaults.
|
||||
cfg.SanitizeCodexHeaderDefaults()
|
||||
|
||||
// Sanitize Claude key headers
|
||||
cfg.SanitizeClaudeKeys()
|
||||
|
||||
@@ -771,6 +786,16 @@ func payloadRawString(value any) ([]byte, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
|
||||
// configured Codex header fallback values.
|
||||
func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
|
||||
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
|
||||
@@ -20,9 +20,9 @@ type VertexCompatKey struct {
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// BaseURL optionally overrides the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
// When empty, requests fall back to the default Vertex API base URL.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
@@ -71,10 +71,6 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
continue
|
||||
}
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
|
||||
@@ -30,6 +30,23 @@ type OAuthCallback struct {
|
||||
ErrorDescription string
|
||||
}
|
||||
|
||||
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
|
||||
// the result. The returned channels are buffered (size 1) so the goroutine can
|
||||
// complete even if the caller abandons the channels.
|
||||
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
|
||||
inputCh := make(chan string, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
input, err := promptFn(message)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
inputCh <- input
|
||||
}()
|
||||
return inputCh, errCh
|
||||
}
|
||||
|
||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||
// It returns nil when the input is empty.
|
||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||
|
||||
@@ -1,12 +1,105 @@
|
||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||
// Static model metadata is stored in model_definitions_static_data.go.
|
||||
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||
type staticModelsJSON struct {
|
||||
Claude []*ModelInfo `json:"claude"`
|
||||
Gemini []*ModelInfo `json:"gemini"`
|
||||
Vertex []*ModelInfo `json:"vertex"`
|
||||
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||
AIStudio []*ModelInfo `json:"aistudio"`
|
||||
CodexFree []*ModelInfo `json:"codex-free"`
|
||||
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||
Qwen []*ModelInfo `json:"qwen"`
|
||||
IFlow []*ModelInfo `json:"iflow"`
|
||||
Kimi []*ModelInfo `json:"kimi"`
|
||||
Antigravity []*ModelInfo `json:"antigravity"`
|
||||
}
|
||||
|
||||
// GetClaudeModels returns the standard Claude model definitions.
|
||||
func GetClaudeModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Claude)
|
||||
}
|
||||
|
||||
// GetGeminiModels returns the standard Gemini model definitions.
|
||||
func GetGeminiModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Gemini)
|
||||
}
|
||||
|
||||
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||
func GetGeminiVertexModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Vertex)
|
||||
}
|
||||
|
||||
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||
func GetGeminiCLIModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().GeminiCLI)
|
||||
}
|
||||
|
||||
// GetAIStudioModels returns model definitions for AI Studio.
|
||||
func GetAIStudioModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().AIStudio)
|
||||
}
|
||||
|
||||
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||
func GetCodexFreeModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexFree)
|
||||
}
|
||||
|
||||
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||
func GetCodexTeamModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexTeam)
|
||||
}
|
||||
|
||||
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||
func GetCodexPlusModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexPlus)
|
||||
}
|
||||
|
||||
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||
func GetCodexProModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().CodexPro)
|
||||
}
|
||||
|
||||
// GetQwenModels returns the standard Qwen model definitions.
|
||||
func GetQwenModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Qwen)
|
||||
}
|
||||
|
||||
// GetIFlowModels returns the standard iFlow model definitions.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().IFlow)
|
||||
}
|
||||
|
||||
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||
func GetKimiModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Kimi)
|
||||
}
|
||||
|
||||
// GetAntigravityModels returns the standard Antigravity model definitions.
|
||||
func GetAntigravityModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Antigravity)
|
||||
}
|
||||
|
||||
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*ModelInfo, len(models))
|
||||
for i, m := range models {
|
||||
out[i] = cloneModelInfo(m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||
// It returns nil when the channel is unknown.
|
||||
//
|
||||
@@ -20,7 +113,6 @@ import (
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - kiro
|
||||
// - kilo
|
||||
// - github-copilot
|
||||
// - amazonq
|
||||
@@ -39,7 +131,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
case "aistudio":
|
||||
return GetAIStudioModels()
|
||||
case "codex":
|
||||
return GetOpenAIModels()
|
||||
return GetCodexProModels()
|
||||
case "qwen":
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
@@ -55,28 +147,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
case "amazonq":
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
cfg := GetAntigravityModelConfig()
|
||||
if len(cfg) == 0 {
|
||||
return nil
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(cfg))
|
||||
for modelID, entry := range cfg {
|
||||
if modelID == "" || entry == nil {
|
||||
continue
|
||||
}
|
||||
models = append(models, &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
OwnedBy: "antigravity",
|
||||
Type: "antigravity",
|
||||
Thinking: entry.Thinking,
|
||||
MaxCompletionTokens: entry.MaxCompletionTokens,
|
||||
})
|
||||
}
|
||||
sort.Slice(models, func(i, j int) bool {
|
||||
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
|
||||
})
|
||||
return models
|
||||
return GetAntigravityModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -89,16 +160,18 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := getModels()
|
||||
allModels := [][]*ModelInfo{
|
||||
GetClaudeModels(),
|
||||
GetGeminiModels(),
|
||||
GetGeminiVertexModels(),
|
||||
GetGeminiCLIModels(),
|
||||
GetAIStudioModels(),
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
data.Claude,
|
||||
data.Gemini,
|
||||
data.Vertex,
|
||||
data.GeminiCLI,
|
||||
data.AIStudio,
|
||||
data.CodexPro,
|
||||
data.Qwen,
|
||||
data.IFlow,
|
||||
data.Kimi,
|
||||
data.Antigravity,
|
||||
GetGitHubCopilotModels(),
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
@@ -107,20 +180,11 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
if m != nil && m.ID == modelID {
|
||||
return m
|
||||
return cloneModelInfo(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check Antigravity static config
|
||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||
return &ModelInfo{
|
||||
ID: modelID,
|
||||
Thinking: cfg.Thinking,
|
||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -64,6 +64,11 @@ type ModelInfo struct {
|
||||
UserDefined bool `json:"-"`
|
||||
}
|
||||
|
||||
type availableModelsCacheEntry struct {
|
||||
models []map[string]any
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||
// Values are interpreted in provider-native token units.
|
||||
type ThinkingSupport struct {
|
||||
@@ -118,6 +123,8 @@ type ModelRegistry struct {
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
mutex *sync.RWMutex
|
||||
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
|
||||
availableModelsCache map[string]availableModelsCacheEntry
|
||||
// hook is an optional callback sink for model registration changes
|
||||
hook ModelRegistryHook
|
||||
}
|
||||
@@ -130,15 +137,28 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
availableModelsCache: make(map[string]availableModelsCacheEntry),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
|
||||
if r.availableModelsCache == nil {
|
||||
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
|
||||
if len(r.availableModelsCache) == 0 {
|
||||
return
|
||||
}
|
||||
clear(r.availableModelsCache)
|
||||
}
|
||||
|
||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
@@ -153,9 +173,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
}
|
||||
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||
return info
|
||||
return cloneModelInfo(info)
|
||||
}
|
||||
return LookupStaticModelInfo(modelID)
|
||||
return cloneModelInfo(LookupStaticModelInfo(modelID))
|
||||
}
|
||||
|
||||
// SetHook sets an optional hook for observing model registration changes.
|
||||
@@ -169,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
||||
}
|
||||
|
||||
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||
const modelQuotaExceededWindow = 5 * time.Minute
|
||||
|
||||
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||
hook := r.hook
|
||||
@@ -213,6 +234,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
||||
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
provider := strings.ToLower(clientProvider)
|
||||
uniqueModelIDs := make([]string, 0, len(models))
|
||||
@@ -238,6 +260,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
}
|
||||
@@ -265,6 +288,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
} else {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||
misc.LogCredentialSeparator()
|
||||
@@ -367,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
reg.LastUpdated = now
|
||||
// Re-registering an existing client/model binding starts a fresh registry
|
||||
// snapshot for that binding. Cooldown and suspension are transient
|
||||
// scheduling state and must not survive this reconciliation step.
|
||||
if reg.QuotaExceededClients != nil {
|
||||
delete(reg.QuotaExceededClients, clientID)
|
||||
}
|
||||
@@ -408,6 +435,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||
@@ -511,6 +539,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
copyThinking := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
copyModel.Thinking = ©Thinking
|
||||
}
|
||||
return ©Model
|
||||
}
|
||||
|
||||
@@ -540,6 +575,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.unregisterClientInternal(clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
}
|
||||
|
||||
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
||||
@@ -606,9 +642,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -620,9 +659,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -638,6 +679,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil {
|
||||
@@ -651,6 +693,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
}
|
||||
registration.SuspendedClients[clientID] = reason
|
||||
registration.LastUpdated = time.Now()
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
if reason != "" {
|
||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||
} else {
|
||||
@@ -668,6 +711,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||
@@ -678,6 +722,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
}
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
registration.LastUpdated = time.Now()
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||
}
|
||||
|
||||
@@ -713,22 +758,51 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
|
||||
// Returns:
|
||||
// - []map[string]any: List of available models in the requested format
|
||||
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
now := time.Now()
|
||||
|
||||
models := make([]map[string]any, 0)
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
r.mutex.RLock()
|
||||
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||
models := cloneModelMaps(cache.models)
|
||||
r.mutex.RUnlock()
|
||||
return models
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.ensureAvailableModelsCacheLocked()
|
||||
|
||||
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||
return cloneModelMaps(cache.models)
|
||||
}
|
||||
|
||||
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
|
||||
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
|
||||
models: cloneModelMaps(models),
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||
models := make([]map[string]any, 0, len(r.models))
|
||||
var expiresAt time.Time
|
||||
|
||||
for _, registration := range r.models {
|
||||
// Check if model has any non-quota-exceeded clients
|
||||
availableClients := registration.Count
|
||||
now := time.Now()
|
||||
|
||||
// Count clients that have exceeded quota but haven't recovered yet
|
||||
expiredClients := 0
|
||||
for _, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime == nil {
|
||||
continue
|
||||
}
|
||||
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
|
||||
if now.Before(recoveryAt) {
|
||||
expiredClients++
|
||||
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||
expiresAt = recoveryAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -749,7 +823,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
// Include models that have available clients, or those solely cooling down.
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
model := r.convertModelToMap(registration.Info, handlerType)
|
||||
if model != nil {
|
||||
@@ -758,7 +831,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
return models, expiresAt
|
||||
}
|
||||
|
||||
func cloneModelMaps(models []map[string]any) []map[string]any {
|
||||
cloned := make([]map[string]any, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
cloned = append(cloned, nil)
|
||||
continue
|
||||
}
|
||||
copyModel := make(map[string]any, len(model))
|
||||
for key, value := range model {
|
||||
copyModel[key] = cloneModelMapValue(value)
|
||||
}
|
||||
cloned = append(cloned, copyModel)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneModelMapValue(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
copyMap := make(map[string]any, len(typed))
|
||||
for key, entry := range typed {
|
||||
copyMap[key] = cloneModelMapValue(entry)
|
||||
}
|
||||
return copyMap
|
||||
case []any:
|
||||
copySlice := make([]any, len(typed))
|
||||
for i, entry := range typed {
|
||||
copySlice[i] = cloneModelMapValue(entry)
|
||||
}
|
||||
return copySlice
|
||||
case []string:
|
||||
return append([]string(nil), typed...)
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||
@@ -822,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
return nil
|
||||
}
|
||||
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
now := time.Now()
|
||||
result := make([]*ModelInfo, 0, len(providerModels))
|
||||
|
||||
@@ -844,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
@@ -874,11 +983,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
||||
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
if entry.info != nil {
|
||||
result = append(result, entry.info)
|
||||
result = append(result, cloneModelInfo(entry.info))
|
||||
continue
|
||||
}
|
||||
if ok && registration != nil && registration.Info != nil {
|
||||
result = append(result, registration.Info)
|
||||
result = append(result, cloneModelInfo(registration.Info))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -898,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
|
||||
// Count clients that have exceeded quota but haven't recovered yet
|
||||
expiredClients := 0
|
||||
for _, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
@@ -987,13 +1095,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||
if reg.Providers != nil {
|
||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||
return info
|
||||
return cloneModelInfo(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to global info (last registered)
|
||||
return reg.Info
|
||||
return cloneModelInfo(reg.Info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1033,7 +1141,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
result["supported_parameters"] = model.SupportedParameters
|
||||
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if len(model.SupportedEndpoints) > 0 {
|
||||
result["supported_endpoints"] = model.SupportedEndpoints
|
||||
@@ -1094,13 +1202,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
||||
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedInputModalities) > 0 {
|
||||
result["supportedInputModalities"] = model.SupportedInputModalities
|
||||
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
|
||||
}
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
||||
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -1129,16 +1237,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
invalidated := false
|
||||
|
||||
for modelID, registration := range r.models {
|
||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
invalidated = true
|
||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
if invalidated {
|
||||
r.invalidateAvailableModelsCacheLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||
@@ -1152,8 +1264,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
// - string: The model ID of the first available model, or empty string if none available
|
||||
// - error: An error if no models are available
|
||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Get all available models for this handler type
|
||||
models := r.GetAvailableModels(handlerType)
|
||||
@@ -1213,13 +1323,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
result = append(result, cloneModelInfo(info))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
result = append(result, cloneModelInfo(reg.Info))
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
54
internal/registry/model_registry_cache_test.go
Normal file
54
internal/registry/model_registry_cache_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package registry
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||
|
||||
first := r.GetAvailableModels("openai")
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(first))
|
||||
}
|
||||
first[0]["id"] = "mutated"
|
||||
first[0]["display_name"] = "Mutated"
|
||||
|
||||
second := r.GetAvailableModels("openai")
|
||||
if got := second[0]["id"]; got != "m1" {
|
||||
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
|
||||
}
|
||||
if got := second[0]["display_name"]; got != "Model One" {
|
||||
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||
|
||||
models := r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(models))
|
||||
}
|
||||
if got := models[0]["display_name"]; got != "Model One" {
|
||||
t.Fatalf("expected initial display_name Model One, got %v", got)
|
||||
}
|
||||
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
|
||||
models = r.GetAvailableModels("openai")
|
||||
if got := models[0]["display_name"]; got != "Model One Updated" {
|
||||
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
|
||||
}
|
||||
|
||||
r.SuspendClientModel("client-1", "m1", "manual")
|
||||
models = r.GetAvailableModels("openai")
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("expected no available models after suspension, got %d", len(models))
|
||||
}
|
||||
|
||||
r.ResumeClientModel("client-1", "m1")
|
||||
models = r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected model to reappear after resume, got %d", len(models))
|
||||
}
|
||||
}
|
||||
149
internal/registry/model_registry_safety_test.go
Normal file
149
internal/registry/model_registry_safety_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetModelInfoReturnsClone(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetModelInfo("m1", "gemini")
|
||||
if first == nil {
|
||||
t.Fatal("expected model info")
|
||||
}
|
||||
first.DisplayName = "mutated"
|
||||
first.Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetModelInfo("m1", "gemini")
|
||||
if second.DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
|
||||
}
|
||||
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelsForClientReturnsClones(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetModelsForClient("client-1")
|
||||
if len(first) != 1 || first[0] == nil {
|
||||
t.Fatalf("expected one model, got %+v", first)
|
||||
}
|
||||
first[0].DisplayName = "mutated"
|
||||
first[0].Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetModelsForClient("client-1")
|
||||
if len(second) != 1 || second[0] == nil {
|
||||
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||
}
|
||||
if second[0].DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||
}
|
||||
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||
}})
|
||||
|
||||
first := r.GetAvailableModelsByProvider("gemini")
|
||||
if len(first) != 1 || first[0] == nil {
|
||||
t.Fatalf("expected one model, got %+v", first)
|
||||
}
|
||||
first[0].DisplayName = "mutated"
|
||||
first[0].Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := r.GetAvailableModelsByProvider("gemini")
|
||||
if len(second) != 1 || second[0] == nil {
|
||||
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||
}
|
||||
if second[0].DisplayName != "Model One" {
|
||||
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||
}
|
||||
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
|
||||
r.SetModelQuotaExceeded("client-1", "m1")
|
||||
if models := r.GetAvailableModels("openai"); len(models) != 1 {
|
||||
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
|
||||
}
|
||||
|
||||
r.mutex.Lock()
|
||||
quotaTime := time.Now().Add(-6 * time.Minute)
|
||||
r.models["m1"].QuotaExceededClients["client-1"] = "aTime
|
||||
r.mutex.Unlock()
|
||||
|
||||
r.CleanupExpiredQuotas()
|
||||
|
||||
if count := r.GetModelCount("m1"); count != 1 {
|
||||
t.Fatalf("expected model count 1 after cleanup, got %d", count)
|
||||
}
|
||||
models := r.GetAvailableModels("openai")
|
||||
if len(models) != 1 {
|
||||
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
|
||||
}
|
||||
if got := models[0]["id"]; got != "m1" {
|
||||
t.Fatalf("expected model id m1, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
r.RegisterClient("client-1", "openai", []*ModelInfo{{
|
||||
ID: "m1",
|
||||
DisplayName: "Model One",
|
||||
SupportedParameters: []string{"temperature", "top_p"},
|
||||
}})
|
||||
|
||||
first := r.GetAvailableModels("openai")
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("expected one model, got %d", len(first))
|
||||
}
|
||||
params, ok := first[0]["supported_parameters"].([]string)
|
||||
if !ok || len(params) != 2 {
|
||||
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
|
||||
}
|
||||
params[0] = "mutated"
|
||||
|
||||
second := r.GetAvailableModels("openai")
|
||||
params, ok = second[0]["supported_parameters"].([]string)
|
||||
if !ok || len(params) != 2 || params[0] != "temperature" {
|
||||
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
|
||||
first := LookupModelInfo("glm-4.6")
|
||||
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
|
||||
t.Fatalf("expected static model with thinking levels, got %+v", first)
|
||||
}
|
||||
first.Thinking.Levels[0] = "mutated"
|
||||
|
||||
second := LookupModelInfo("glm-4.6")
|
||||
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
|
||||
t.Fatalf("expected static lookup clone, got %+v", second)
|
||||
}
|
||||
}
|
||||
372
internal/registry/model_updater.go
Normal file
372
internal/registry/model_updater.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
modelsFetchTimeout = 30 * time.Second
|
||||
modelsRefreshInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
var modelsURLs = []string{
|
||||
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
|
||||
"https://models.router-for.me/models.json",
|
||||
}
|
||||
|
||||
//go:embed models/models.json
|
||||
var embeddedModelsJSON []byte
|
||||
|
||||
type modelStore struct {
|
||||
mu sync.RWMutex
|
||||
data *staticModelsJSON
|
||||
}
|
||||
|
||||
var modelsCatalogStore = &modelStore{}
|
||||
|
||||
var updaterOnce sync.Once
|
||||
|
||||
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
|
||||
// changedProviders contains the provider names whose model definitions changed.
|
||||
type ModelRefreshCallback func(changedProviders []string)
|
||||
|
||||
var (
|
||||
refreshCallbackMu sync.Mutex
|
||||
refreshCallback ModelRefreshCallback
|
||||
pendingRefreshChanges []string
|
||||
)
|
||||
|
||||
// SetModelRefreshCallback registers a callback that is invoked when startup or
|
||||
// periodic model refresh detects changes. Only one callback is supported;
|
||||
// subsequent calls replace the previous callback.
|
||||
func SetModelRefreshCallback(cb ModelRefreshCallback) {
|
||||
refreshCallbackMu.Lock()
|
||||
refreshCallback = cb
|
||||
var pending []string
|
||||
if cb != nil && len(pendingRefreshChanges) > 0 {
|
||||
pending = append([]string(nil), pendingRefreshChanges...)
|
||||
pendingRefreshChanges = nil
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
|
||||
if cb != nil && len(pending) > 0 {
|
||||
cb(pending)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Load embedded data as fallback on startup.
|
||||
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// StartModelsUpdater starts a background updater that fetches models
|
||||
// immediately on startup and then refreshes the model catalog every 3 hours.
|
||||
// Safe to call multiple times; only one updater will run.
|
||||
func StartModelsUpdater(ctx context.Context) {
|
||||
updaterOnce.Do(func() {
|
||||
go runModelsUpdater(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func runModelsUpdater(ctx context.Context) {
|
||||
tryStartupRefresh(ctx)
|
||||
periodicRefresh(ctx)
|
||||
}
|
||||
|
||||
func periodicRefresh(ctx context.Context) {
|
||||
ticker := time.NewTicker(modelsRefreshInterval)
|
||||
defer ticker.Stop()
|
||||
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
tryPeriodicRefresh(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryPeriodicRefresh fetches models from remote, compares with the current
|
||||
// catalog, and notifies the registered callback if any provider changed.
|
||||
func tryPeriodicRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "periodic model refresh")
|
||||
}
|
||||
|
||||
// tryStartupRefresh fetches models from remote in the background during
|
||||
// process startup. It uses the same change detection as periodic refresh so
|
||||
// existing auth registrations can be updated after the callback is registered.
|
||||
func tryStartupRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "startup model refresh")
|
||||
}
|
||||
|
||||
func tryRefreshModels(ctx context.Context, label string) {
|
||||
oldData := getModels()
|
||||
|
||||
parsed, url := fetchModelsFromRemote(ctx)
|
||||
if parsed == nil {
|
||||
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
|
||||
return
|
||||
}
|
||||
|
||||
// Detect changes before updating store.
|
||||
changed := detectChangedProviders(oldData, parsed)
|
||||
|
||||
// Update store with new data regardless.
|
||||
modelsCatalogStore.mu.Lock()
|
||||
modelsCatalogStore.data = parsed
|
||||
modelsCatalogStore.mu.Unlock()
|
||||
|
||||
if len(changed) == 0 {
|
||||
log.Infof("%s completed from %s, no changes detected", label, url)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
|
||||
notifyModelRefresh(changed)
|
||||
}
|
||||
|
||||
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
|
||||
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
|
||||
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
|
||||
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||
for _, url := range modelsURLs {
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
log.Debugf("models fetch request creation failed for %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
log.Debugf("models fetch failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
resp.Body.Close()
|
||||
cancel()
|
||||
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("models fetch read error from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var parsed staticModelsJSON
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
log.Warnf("models parse failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
if err := validateModelsCatalog(&parsed); err != nil {
|
||||
log.Warnf("models validate failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
return &parsed, url
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// detectChangedProviders compares two model catalogs and returns provider names
|
||||
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
|
||||
// under a single "codex" provider.
|
||||
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||
if oldData == nil || newData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type section struct {
|
||||
provider string
|
||||
oldList []*ModelInfo
|
||||
newList []*ModelInfo
|
||||
}
|
||||
|
||||
sections := []section{
|
||||
{"claude", oldData.Claude, newData.Claude},
|
||||
{"gemini", oldData.Gemini, newData.Gemini},
|
||||
{"vertex", oldData.Vertex, newData.Vertex},
|
||||
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
|
||||
{"aistudio", oldData.AIStudio, newData.AIStudio},
|
||||
{"codex", oldData.CodexFree, newData.CodexFree},
|
||||
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
||||
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||
{"qwen", oldData.Qwen, newData.Qwen},
|
||||
{"iflow", oldData.IFlow, newData.IFlow},
|
||||
{"kimi", oldData.Kimi, newData.Kimi},
|
||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(sections))
|
||||
var changed []string
|
||||
for _, s := range sections {
|
||||
if seen[s.provider] {
|
||||
continue
|
||||
}
|
||||
if modelSectionChanged(s.oldList, s.newList) {
|
||||
changed = append(changed, s.provider)
|
||||
seen[s.provider] = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// modelSectionChanged reports whether two model slices differ.
|
||||
func modelSectionChanged(a, b []*ModelInfo) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return false
|
||||
}
|
||||
aj, err1 := json.Marshal(a)
|
||||
bj, err2 := json.Marshal(b)
|
||||
if err1 != nil || err2 != nil {
|
||||
return true
|
||||
}
|
||||
return string(aj) != string(bj)
|
||||
}
|
||||
|
||||
func notifyModelRefresh(changedProviders []string) {
|
||||
if len(changedProviders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
refreshCallbackMu.Lock()
|
||||
cb := refreshCallback
|
||||
if cb == nil {
|
||||
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
|
||||
refreshCallbackMu.Unlock()
|
||||
return
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
cb(changedProviders)
|
||||
}
|
||||
|
||||
func mergeProviderNames(existing, incoming []string) []string {
|
||||
if len(incoming) == 0 {
|
||||
return existing
|
||||
}
|
||||
seen := make(map[string]struct{}, len(existing)+len(incoming))
|
||||
merged := make([]string, 0, len(existing)+len(incoming))
|
||||
for _, provider := range existing {
|
||||
name := strings.ToLower(strings.TrimSpace(provider))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
for _, provider := range incoming {
|
||||
name := strings.ToLower(strings.TrimSpace(provider))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func loadModelsFromBytes(data []byte, source string) error {
|
||||
var parsed staticModelsJSON
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
return fmt.Errorf("%s: decode models catalog: %w", source, err)
|
||||
}
|
||||
if err := validateModelsCatalog(&parsed); err != nil {
|
||||
return fmt.Errorf("%s: validate models catalog: %w", source, err)
|
||||
}
|
||||
|
||||
modelsCatalogStore.mu.Lock()
|
||||
modelsCatalogStore.data = &parsed
|
||||
modelsCatalogStore.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func getModels() *staticModelsJSON {
|
||||
modelsCatalogStore.mu.RLock()
|
||||
defer modelsCatalogStore.mu.RUnlock()
|
||||
return modelsCatalogStore.data
|
||||
}
|
||||
|
||||
func validateModelsCatalog(data *staticModelsJSON) error {
|
||||
if data == nil {
|
||||
return fmt.Errorf("catalog is nil")
|
||||
}
|
||||
|
||||
requiredSections := []struct {
|
||||
name string
|
||||
models []*ModelInfo
|
||||
}{
|
||||
{name: "claude", models: data.Claude},
|
||||
{name: "gemini", models: data.Gemini},
|
||||
{name: "vertex", models: data.Vertex},
|
||||
{name: "gemini-cli", models: data.GeminiCLI},
|
||||
{name: "aistudio", models: data.AIStudio},
|
||||
{name: "codex-free", models: data.CodexFree},
|
||||
{name: "codex-team", models: data.CodexTeam},
|
||||
{name: "codex-plus", models: data.CodexPlus},
|
||||
{name: "codex-pro", models: data.CodexPro},
|
||||
{name: "qwen", models: data.Qwen},
|
||||
{name: "iflow", models: data.IFlow},
|
||||
{name: "kimi", models: data.Kimi},
|
||||
{name: "antigravity", models: data.Antigravity},
|
||||
}
|
||||
|
||||
for _, section := range requiredSections {
|
||||
if err := validateModelSection(section.name, section.models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateModelSection(section string, models []*ModelInfo) error {
|
||||
if len(models) == 0 {
|
||||
return fmt.Errorf("%s section is empty", section)
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for i, model := range models {
|
||||
if model == nil {
|
||||
return fmt.Errorf("%s[%d] is null", section, i)
|
||||
}
|
||||
modelID := strings.TrimSpace(model.ID)
|
||||
if modelID == "" {
|
||||
return fmt.Errorf("%s[%d] has empty id", section, i)
|
||||
}
|
||||
if _, exists := seen[modelID]; exists {
|
||||
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
|
||||
}
|
||||
seen[modelID] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
2683
internal/registry/models/models.json
Normal file
2683
internal/registry/models/models.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -164,7 +164,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -296,7 +296,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
||||
return false
|
||||
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||
}
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for AI Studio).
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -43,7 +42,6 @@ const (
|
||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||
@@ -55,78 +53,8 @@ const (
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||
antigravityPrimaryModelsCache struct {
|
||||
mu sync.RWMutex
|
||||
models []*registry.ModelInfo
|
||||
}
|
||||
)
|
||||
|
||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*registry.ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneAntigravityModelInfo(model))
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *model
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
thinkingClone := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
clone.Thinking = &thinkingClone
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||
cloned := cloneAntigravityModels(models)
|
||||
if len(cloned) == 0 {
|
||||
return false
|
||||
}
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = cloned
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
antigravityPrimaryModelsCache.mu.RLock()
|
||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||
return cloned
|
||||
}
|
||||
|
||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
models := loadAntigravityPrimaryModels()
|
||||
if len(models) > 0 {
|
||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -380,7 +308,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -584,7 +512,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -763,31 +691,42 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
partsJSON, _ := json.Marshal(parts)
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
||||
updatedTemplate, _ := sjson.SetRawBytes([]byte(responseTemplate), "candidates.0.content.parts", partsJSON)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
if role != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.content.role", role)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if finishReason != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.finishReason", finishReason)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if modelVersion != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "modelVersion", modelVersion)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if responseID != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "responseId", responseID)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
if usageRaw != "" {
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
||||
updatedTemplate, _ = sjson.SetRawBytes([]byte(responseTemplate), "usageMetadata", []byte(usageRaw))
|
||||
responseTemplate = string(updatedTemplate)
|
||||
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.totalTokenCount", 0)
|
||||
responseTemplate = string(updatedTemplate)
|
||||
}
|
||||
|
||||
output := `{"response":{},"traceId":""}`
|
||||
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
||||
updatedOutput, _ := sjson.SetRawBytes([]byte(output), "response", []byte(responseTemplate))
|
||||
output = string(updatedOutput)
|
||||
if traceID != "" {
|
||||
output, _ = sjson.Set(output, "traceId", traceID)
|
||||
updatedOutput, _ = sjson.SetBytes([]byte(output), "traceId", traceID)
|
||||
output = string(updatedOutput)
|
||||
}
|
||||
return []byte(output)
|
||||
}
|
||||
@@ -952,12 +891,12 @@ attemptLoop:
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -1115,7 +1054,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
@@ -1150,168 +1089,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
|
||||
var payload []byte
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||
}
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
payload = []byte(`{}`)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if host := resolveHost(baseURL); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
modelConfig := registry.GetAntigravityModelConfig()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for originalName, modelData := range result.Map() {
|
||||
modelID := strings.TrimSpace(originalName)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
|
||||
// Extract displayName from upstream response, fallback to modelID
|
||||
displayName := modelData.Get("displayName").String()
|
||||
if displayName == "" {
|
||||
displayName = modelID
|
||||
}
|
||||
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: modelID,
|
||||
Name: modelID,
|
||||
Description: displayName,
|
||||
DisplayName: displayName,
|
||||
Version: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
|
||||
// Build input modalities from upstream capability flags.
|
||||
inputModalities := []string{"TEXT"}
|
||||
if modelData.Get("supportsImages").Bool() {
|
||||
inputModalities = append(inputModalities, "IMAGE")
|
||||
}
|
||||
if modelData.Get("supportsVideo").Bool() {
|
||||
inputModalities = append(inputModalities, "VIDEO")
|
||||
}
|
||||
modelInfo.SupportedInputModalities = inputModalities
|
||||
modelInfo.SupportedOutputModalities = []string{"TEXT"}
|
||||
|
||||
// Token limits from upstream.
|
||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||
modelInfo.InputTokenLimit = int(maxTok)
|
||||
}
|
||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||
modelInfo.OutputTokenLimit = int(maxOut)
|
||||
}
|
||||
|
||||
// Supported generation methods (Gemini v1beta convention).
|
||||
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
|
||||
|
||||
// Look up Thinking support from static config using upstream model name.
|
||||
if modelCfg != nil {
|
||||
if modelCfg.Thinking != nil {
|
||||
modelInfo.Thinking = modelCfg.Thinking
|
||||
}
|
||||
if modelCfg.MaxCompletionTokens > 0 {
|
||||
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
|
||||
}
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
storeAntigravityPrimaryModels(models)
|
||||
return models
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||
@@ -1499,19 +1276,20 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
|
||||
// if useAntigravitySchema {
|
||||
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.role", "user")
|
||||
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
// for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
// payloadStr, _ = sjson.SetRawBytes([]byte(payloadStr), "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payloadStr = string(updated)
|
||||
} else {
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
@@ -1733,8 +1511,9 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
}
|
||||
|
||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
template := payload
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "userAgent", "antigravity")
|
||||
|
||||
isImageModel := strings.Contains(modelName, "image")
|
||||
|
||||
@@ -1744,28 +1523,28 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
} else {
|
||||
reqType = "agent"
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestType", reqType)
|
||||
template, _ = sjson.SetBytes(template, "requestType", reqType)
|
||||
|
||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||
if projectID != "" {
|
||||
template, _ = sjson.Set(template, "project", projectID)
|
||||
template, _ = sjson.SetBytes(template, "project", projectID)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
template, _ = sjson.SetBytes(template, "project", generateProjectID())
|
||||
}
|
||||
|
||||
if isImageModel {
|
||||
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
|
||||
template, _ = sjson.SetBytes(template, "requestId", generateImageGenRequestID())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
template, _ = sjson.SetBytes(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.SetBytes(template, "request.sessionId", generateStableSessionID(payload))
|
||||
}
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
|
||||
template, _ = sjson.Delete(template, "toolConfig")
|
||||
template, _ = sjson.DeleteBytes(template, "request.safetySettings")
|
||||
if toolConfig := gjson.GetBytes(template, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(template, "request.toolConfig").Exists() {
|
||||
template, _ = sjson.SetRawBytes(template, "request.toolConfig", []byte(toolConfig.Raw))
|
||||
template, _ = sjson.DeleteBytes(template, "toolConfig")
|
||||
}
|
||||
return []byte(template)
|
||||
return template
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = nil
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
seed := []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||
t.Fatal("expected non-empty model list to update primary cache")
|
||||
}
|
||||
|
||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||
ID: "gpt-5",
|
||||
DisplayName: "GPT-5",
|
||||
SupportedGenerationMethods: []string{"generateContent"},
|
||||
SupportedParameters: []string{"temperature"},
|
||||
Thinking: ®istry.ThinkingSupport{
|
||||
Levels: []string{"high"},
|
||||
},
|
||||
}}); !updated {
|
||||
t.Fatal("expected model cache update")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one cached model, got %d", len(got))
|
||||
}
|
||||
got[0].ID = "mutated-id"
|
||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||
}
|
||||
if len(got[0].SupportedParameters) > 0 {
|
||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||
}
|
||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||
got[0].Thinking.Levels[0] = "mutated-level"
|
||||
}
|
||||
|
||||
again := loadAntigravityPrimaryModels()
|
||||
if len(again) != 1 {
|
||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||
}
|
||||
if again[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||
}
|
||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||
}
|
||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||
}
|
||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||
}
|
||||
}
|
||||
@@ -255,7 +255,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -443,7 +443,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
¶m,
|
||||
)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -561,7 +561,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
@@ -1260,12 +1260,18 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
||||
partJSON := part.Raw
|
||||
if !part.Get("cache_control").Exists() {
|
||||
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
|
||||
updated, _ := sjson.SetBytes([]byte(partJSON), "cache_control.type", "ephemeral")
|
||||
partJSON = string(updated)
|
||||
}
|
||||
result += "," + partJSON
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if system.Type == gjson.String && system.String() != "" {
|
||||
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||
updated, _ := sjson.SetBytes([]byte(partJSON), "text", system.String())
|
||||
partJSON = string(updated)
|
||||
result += "," + partJSON
|
||||
}
|
||||
result += "]"
|
||||
|
||||
@@ -1485,25 +1491,27 @@ func countCacheControlsMap(root map[string]any) int {
|
||||
return count
|
||||
}
|
||||
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
|
||||
ccRaw, exists := obj["cache_control"]
|
||||
if !exists {
|
||||
return
|
||||
return false
|
||||
}
|
||||
cc, ok := asObject(ccRaw)
|
||||
if !ok {
|
||||
*seen5m = true
|
||||
return
|
||||
return false
|
||||
}
|
||||
ttlRaw, ttlExists := cc["ttl"]
|
||||
ttl, ttlIsString := ttlRaw.(string)
|
||||
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||
*seen5m = true
|
||||
return
|
||||
return false
|
||||
}
|
||||
if *seen5m {
|
||||
delete(cc, "ttl")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func findLastCacheControlIndex(arr []any) int {
|
||||
@@ -1599,11 +1607,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
}
|
||||
|
||||
seen5m := false
|
||||
modified := false
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, tool := range tools {
|
||||
if obj, ok := asObject(tool); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1611,7 +1622,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1628,12 +1641,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
if normalizeTTLForBlock(obj, &seen5m) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return payload
|
||||
}
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
|
||||
@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
|
||||
// Payload where no TTL normalization is needed (all blocks use 1h with no
|
||||
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
|
||||
// that json.Marshal would escape to \u003c etc., altering byte identity.
|
||||
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if !bytes.Equal(out, payload) {
|
||||
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
@@ -829,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 1: String system prompt is preserved and converted to a content block
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
system := gjson.GetBytes(out, "system")
|
||||
if !system.IsArray() {
|
||||
t.Fatalf("system should be an array, got %s", system.Type)
|
||||
}
|
||||
|
||||
blocks := system.Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
|
||||
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
|
||||
}
|
||||
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
|
||||
}
|
||||
if blocks[2].Get("text").String() != "You are a helpful assistant." {
|
||||
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
|
||||
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 2: Strict mode drops the string system prompt
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, true)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 3: Empty string system prompt does not produce a spurious block
|
||||
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
|
||||
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 4: Array system prompt is unaffected by the string handling
|
||||
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
|
||||
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[2].Get("text").String() != "Be concise." {
|
||||
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 5: Special characters in string system prompt survive conversion
|
||||
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
||||
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
|
||||
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true)
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, false)
|
||||
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -273,7 +273,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true)
|
||||
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -387,7 +387,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -432,7 +432,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
|
||||
@@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) {
|
||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
@@ -647,7 +647,8 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -31,7 +32,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
||||
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
||||
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
||||
)
|
||||
@@ -57,11 +58,6 @@ type codexWebsocketSession struct {
|
||||
wsURL string
|
||||
authID string
|
||||
|
||||
// connCreateSent tracks whether a `response.create` message has been successfully sent
|
||||
// on the current websocket connection. The upstream expects the first message on each
|
||||
// connection to be `response.create`.
|
||||
connCreateSent bool
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
activeMu sync.Mutex
|
||||
@@ -195,7 +191,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -212,13 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
defer sess.reqMu.Unlock()
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -280,10 +270,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
// execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry == nil && connRetry != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -312,7 +299,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
@@ -357,7 +343,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: out}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
@@ -400,29 +386,23 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
|
||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
|
||||
executionSessionID := executionSessionIDFromOptions(opts)
|
||||
var sess *codexWebsocketSession
|
||||
if executionSessionID != "" {
|
||||
sess = e.getOrCreateSession(executionSessionID)
|
||||
sess.reqMu.Lock()
|
||||
if sess != nil {
|
||||
sess.reqMu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -483,10 +463,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
sess.reqMu.Unlock()
|
||||
return nil, errDialRetry
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -515,7 +492,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
@@ -616,7 +592,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
line := encodeCodexWebsocketAsSSE(payload)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m)
|
||||
for i := range chunks {
|
||||
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
|
||||
if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
|
||||
terminateReason = "context_done"
|
||||
terminateErr = ctx.Err()
|
||||
return
|
||||
@@ -657,31 +633,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
|
||||
return conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
|
||||
func buildCodexWebsocketRequestBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
|
||||
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
|
||||
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
|
||||
//
|
||||
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
|
||||
// so we only use `response.append` after we have initialized the current connection.
|
||||
if allowAppend {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
|
||||
inputNode := gjson.GetBytes(body, "input")
|
||||
wsReqBody := []byte(`{}`)
|
||||
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
|
||||
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
|
||||
return wsReqBody
|
||||
}
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
|
||||
return wsReqBody
|
||||
}
|
||||
}
|
||||
|
||||
// Match codex-rs websocket v2 semantics: every request is `response.create`.
|
||||
// Incremental follow-up turns continue on the same websocket using
|
||||
// `previous_response_id` + incremental `input`, not `response.append`.
|
||||
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
||||
if errSet == nil && len(wsReqBody) > 0 {
|
||||
return wsReqBody
|
||||
@@ -725,21 +684,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
|
||||
}
|
||||
}
|
||||
|
||||
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
|
||||
if sess == nil || conn == nil || len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return
|
||||
}
|
||||
|
||||
sess.connMu.Lock()
|
||||
if sess.conn == conn {
|
||||
sess.connCreateSent = true
|
||||
}
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
|
||||
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -762,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
||||
return dialer
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(proxyURL)
|
||||
setting, errParse := proxyutil.Parse(proxyURL)
|
||||
if errParse != nil {
|
||||
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
|
||||
log.Errorf("codex websockets executor: %v", errParse)
|
||||
return dialer
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
switch setting.Mode {
|
||||
case proxyutil.ModeDirect:
|
||||
dialer.Proxy = nil
|
||||
return dialer
|
||||
case proxyutil.ModeProxy:
|
||||
default:
|
||||
return dialer
|
||||
}
|
||||
|
||||
switch setting.URL.Scheme {
|
||||
case "socks5":
|
||||
var proxyAuth *proxy.Auth
|
||||
if parsedURL.User != nil {
|
||||
username := parsedURL.User.Username()
|
||||
password, _ := parsedURL.User.Password()
|
||||
if setting.URL.User != nil {
|
||||
username := setting.URL.User.Username()
|
||||
password, _ := setting.URL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
||||
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return dialer
|
||||
@@ -786,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
||||
return socksDialer.Dial(network, addr)
|
||||
}
|
||||
case "http", "https":
|
||||
dialer.Proxy = http.ProxyURL(parsedURL)
|
||||
dialer.Proxy = http.ProxyURL(setting.URL)
|
||||
default:
|
||||
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
|
||||
}
|
||||
|
||||
return dialer
|
||||
@@ -844,7 +797,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
return rawJSON, headers
|
||||
}
|
||||
|
||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header {
|
||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
|
||||
if headers == nil {
|
||||
headers = http.Header{}
|
||||
}
|
||||
@@ -857,7 +810,8 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "")
|
||||
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
||||
@@ -872,7 +826,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent)
|
||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
|
||||
isAPIKey := false
|
||||
if auth != nil && auth.Attributes != nil {
|
||||
@@ -900,6 +854,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
return headers
|
||||
}
|
||||
|
||||
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
|
||||
if cfg == nil || auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||
}
|
||||
|
||||
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||
if target == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(target.Get(key)) != "" {
|
||||
return
|
||||
}
|
||||
if source != nil {
|
||||
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
}
|
||||
if val := strings.TrimSpace(configValue); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||
target.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||
if target == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(target.Get(key)) != "" {
|
||||
return
|
||||
}
|
||||
if val := strings.TrimSpace(configValue); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
if source != nil {
|
||||
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
||||
target.Set(key, val)
|
||||
return
|
||||
}
|
||||
}
|
||||
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||
target.Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
type statusErrWithHeaders struct {
|
||||
statusErr
|
||||
headers http.Header
|
||||
@@ -1017,36 +1027,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
|
||||
}
|
||||
}
|
||||
|
||||
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.SetReadDeadline(time.Now())
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
||||
if len(opts.Metadata) == 0 {
|
||||
return ""
|
||||
@@ -1120,7 +1100,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
|
||||
sess.conn = conn
|
||||
sess.wsURL = wsURL
|
||||
sess.authID = authID
|
||||
sess.connCreateSent = false
|
||||
sess.readerConn = conn
|
||||
sess.connMu.Unlock()
|
||||
|
||||
@@ -1206,7 +1185,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
|
||||
return
|
||||
}
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
@@ -1273,7 +1251,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
|
||||
authID := sess.authID
|
||||
wsURL := sess.wsURL
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
|
||||
203
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
203
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %s, want response.create", got)
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
|
||||
t.Fatalf("previous_response_id = %s, want resp-1", got)
|
||||
}
|
||||
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
|
||||
t.Fatalf("input item id mismatch")
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
|
||||
t.Fatalf("unexpected websocket request type: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "my-codex-client/1.0",
|
||||
BetaFeatures: "feature-a,feature-b",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
||||
}
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"User-Agent": "client-ua",
|
||||
"X-Codex-Beta-Features": "client-beta",
|
||||
})
|
||||
headers := http.Header{}
|
||||
headers.Set("User-Agent", "existing-ua")
|
||||
headers.Set("X-Codex-Beta-Features", "existing-beta")
|
||||
|
||||
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
||||
|
||||
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
|
||||
}
|
||||
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"User-Agent": "client-ua",
|
||||
"X-Codex-Beta-Features": "client-beta",
|
||||
})
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "config-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Attributes: map[string]string{"api_key": "sk-test"},
|
||||
}
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
cfg := &config.Config{
|
||||
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||
UserAgent: "config-ua",
|
||||
BetaFeatures: "config-beta",
|
||||
},
|
||||
}
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{"email": "user@example.com"},
|
||||
}
|
||||
req = req.WithContext(contextWithGinHeaders(map[string]string{
|
||||
"User-Agent": "client-ua",
|
||||
}))
|
||||
|
||||
applyCodexHeaders(req, auth, "oauth-token", true, cfg)
|
||||
|
||||
if got := req.Header.Get("User-Agent"); got != "config-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||
}
|
||||
if got := req.Header.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func contextWithGinHeaders(headers map[string]string) context.Context {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
ginCtx.Request.Header = make(http.Header, len(headers))
|
||||
for key, value := range headers {
|
||||
ginCtx.Request.Header.Set(key, value)
|
||||
}
|
||||
return context.WithValue(context.Background(), "gin", ginCtx)
|
||||
}
|
||||
|
||||
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dialer := newProxyAwareWebsocketDialer(
|
||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||
)
|
||||
|
||||
if dialer.Proxy != nil {
|
||||
t.Fatal("expected websocket proxy function to be nil for direct mode")
|
||||
}
|
||||
}
|
||||
@@ -224,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -401,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -430,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
@@ -544,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
@@ -811,18 +811,18 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
||||
|
||||
if !hasInlineData {
|
||||
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
||||
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
||||
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := `[]`
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
||||
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
|
||||
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := []byte(`[]`)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
|
||||
|
||||
parts := contentArray[0].Get("parts").Array()
|
||||
for j := 0; j < len(parts); j++ {
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson)
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -321,12 +321,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -415,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
@@ -527,18 +527,18 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
||||
|
||||
if !hasInlineData {
|
||||
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
||||
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
||||
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := `[]`
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
||||
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
|
||||
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||
newPartsJson := []byte(`[]`)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
|
||||
|
||||
parts := contentArray[0].Get("parts").Array()
|
||||
for j := 0; j < len(parts); j++ {
|
||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
||||
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", newPartsJson)
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
@@ -524,7 +524,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -636,12 +636,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
action := getVertexAction(baseModel, true)
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
// Imagen models don't support streaming, skip SSE params
|
||||
@@ -760,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
baseURL = "https://aiplatform.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
|
||||
|
||||
@@ -941,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
@@ -221,13 +221,13 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
var param any
|
||||
converted := ""
|
||||
var converted []byte
|
||||
if useResponses && from.String() == "claude" {
|
||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||
} else {
|
||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: converted}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -374,14 +374,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
var chunks [][]byte
|
||||
if useResponses && from.String() == "claude" {
|
||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||
} else {
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
}
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -522,9 +522,9 @@ func detectLastConversationRole(body []byte) string {
|
||||
}
|
||||
|
||||
switch item.Get("type").String() {
|
||||
case "function_call", "function_call_arguments":
|
||||
case "function_call", "function_call_arguments", "computer_call":
|
||||
return "assistant"
|
||||
case "function_call_output", "function_call_response", "tool_result":
|
||||
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||
return "tool"
|
||||
}
|
||||
}
|
||||
@@ -653,6 +653,7 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
body = stripGitHubCopilotResponsesUnsupportedFields(body)
|
||||
input := gjson.GetBytes(body, "input")
|
||||
if input.Exists() {
|
||||
// If input is already a string or array, keep it as-is.
|
||||
@@ -825,6 +826,12 @@ func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
|
||||
// GitHub Copilot /responses rejects service_tier, so always remove it.
|
||||
body, _ = sjson.DeleteBytes(body, "service_tier")
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
@@ -832,6 +839,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
if isGitHubCopilotResponsesBuiltinTool(toolType) {
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||
continue
|
||||
}
|
||||
// Accept OpenAI format (type="function") and Claude format
|
||||
// (no type field, but has top-level name + input_schema).
|
||||
if toolType != "" && toolType != "function" {
|
||||
@@ -879,6 +890,10 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
}
|
||||
if toolChoice.Type == gjson.JSON {
|
||||
choiceType := toolChoice.Get("type").String()
|
||||
if isGitHubCopilotResponsesBuiltinTool(choiceType) {
|
||||
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(toolChoice.Raw))
|
||||
return body
|
||||
}
|
||||
if choiceType == "function" {
|
||||
name := toolChoice.Get("name").String()
|
||||
if name == "" {
|
||||
@@ -896,6 +911,15 @@ func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func isGitHubCopilotResponsesBuiltinTool(toolType string) bool {
|
||||
switch strings.TrimSpace(toolType) {
|
||||
case "computer", "computer_use_preview":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func collectTextFromNode(node gjson.Result) string {
|
||||
if !node.Exists() {
|
||||
return ""
|
||||
@@ -953,7 +977,7 @@ type githubCopilotResponsesStreamState struct {
|
||||
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) []byte {
|
||||
root := gjson.ParseBytes(data)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
@@ -1043,10 +1067,10 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
}
|
||||
return out
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &githubCopilotResponsesStreamState{
|
||||
TextBlockIndex: -1,
|
||||
@@ -1068,7 +1092,10 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
|
||||
event := gjson.GetBytes(payload, "type").String()
|
||||
results := make([]string, 0, 4)
|
||||
results := make([][]byte, 0, 4)
|
||||
appendResult := func(chunk string) {
|
||||
results = append(results, []byte(chunk))
|
||||
}
|
||||
ensureMessageStart := func() {
|
||||
if state.MessageStarted {
|
||||
return
|
||||
@@ -1076,7 +1103,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
||||
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
||||
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
||||
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
|
||||
appendResult("event: message_start\ndata: " + messageStart + "\n\n")
|
||||
state.MessageStarted = true
|
||||
}
|
||||
startTextBlockIfNeeded := func() {
|
||||
@@ -1089,7 +1116,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
|
||||
state.TextBlockStarted = true
|
||||
}
|
||||
stopTextBlockIfNeeded := func() {
|
||||
@@ -1098,7 +1125,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
|
||||
state.TextBlockStarted = false
|
||||
state.TextBlockIndex = -1
|
||||
}
|
||||
@@ -1128,7 +1155,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
||||
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + contentDelta + "\n\n")
|
||||
}
|
||||
case "response.reasoning_summary_part.added":
|
||||
ensureMessageStart()
|
||||
@@ -1137,7 +1164,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
state.NextContentIndex++
|
||||
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
|
||||
appendResult("event: content_block_start\ndata: " + thinkingStart + "\n\n")
|
||||
case "response.reasoning_summary_text.delta":
|
||||
if state.ReasoningActive {
|
||||
delta := gjson.GetBytes(payload, "delta").String()
|
||||
@@ -1145,14 +1172,14 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
|
||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + thinkingDelta + "\n\n")
|
||||
}
|
||||
}
|
||||
case "response.reasoning_summary_part.done":
|
||||
if state.ReasoningActive {
|
||||
thinkingStop := `{"type":"content_block_stop","index":0}`
|
||||
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
|
||||
appendResult("event: content_block_stop\ndata: " + thinkingStop + "\n\n")
|
||||
state.ReasoningActive = false
|
||||
}
|
||||
case "response.output_item.added":
|
||||
@@ -1180,7 +1207,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
|
||||
case "response.output_item.delta":
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if item.Get("type").String() != "function_call" {
|
||||
@@ -1200,7 +1227,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
|
||||
case "response.function_call_arguments.delta":
|
||||
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
|
||||
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
|
||||
@@ -1217,7 +1244,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
|
||||
case "response.output_item.done":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
@@ -1228,7 +1255,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
|
||||
case "response.completed":
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
@@ -1252,8 +1279,8 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
||||
if cachedTokens > 0 {
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
|
||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
appendResult("event: message_delta\ndata: " + messageDelta + "\n\n")
|
||||
appendResult("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
state.MessageStopSent = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +132,19 @@ func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"user text","service_tier":"default"}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
|
||||
if gjson.GetBytes(got, "service_tier").Exists() {
|
||||
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "input").String() != "user text" {
|
||||
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
|
||||
1320
internal/runtime/executor/gitlab_executor.go
Normal file
1320
internal/runtime/executor/gitlab_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
469
internal/runtime/executor/gitlab_executor_test.go
Normal file
469
internal/runtime/executor/gitlab_executor_test.go
Normal file
@@ -0,0 +1,469 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitLabExecutorExecuteUsesChatEndpoint(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != gitLabChatEndpoint {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
_, _ = w.Write([]byte(`"chat response"`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "chat response" {
|
||||
t.Fatalf("expected chat response, got %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected resolved model, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteFallsBackToCodeSuggestions(t *testing.T) {
|
||||
chatCalls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case gitLabChatEndpoint:
|
||||
chatCalls++
|
||||
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||
case gitLabCodeSuggestionsEndpoint:
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"text": "fallback response",
|
||||
}},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"personal_access_token": "glpat-token",
|
||||
"auth_method": "pat",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"write code"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if chatCalls != 1 {
|
||||
t.Fatalf("expected chat endpoint to be tried once, got %d", chatCalls)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "fallback response" {
|
||||
t.Fatalf("expected fallback response, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesAnthropicGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{
|
||||
"model":"gitlab-duo",
|
||||
"messages":[{"role":"user","content":[{"type":"text","text":"list files"}]}],
|
||||
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}],
|
||||
"max_tokens":128
|
||||
}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotModel != "claude-sonnet-4-5" {
|
||||
t.Fatalf("model = %q, want claude-sonnet-4-5", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "content.0.type").String(); got != "tool_use" {
|
||||
t.Fatalf("expected tool_use response, got %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "content.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("expected tool name Bash, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from openai gateway\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from openai gateway\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-5-codex",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotModel != "gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from openai gateway" {
|
||||
t.Fatalf("expected openai gateway response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "oauth-refreshed",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"token_type": "Bearer",
|
||||
"scope": "api read_user",
|
||||
"created_at": 1710000000,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"oauth_client_secret": "client-secret",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := exec.Refresh(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("Refresh() error = %v", err)
|
||||
}
|
||||
if got := updated.Metadata["access_token"]; got != "oauth-refreshed" {
|
||||
t.Fatalf("expected refreshed access token, got %#v", got)
|
||||
}
|
||||
if got := updated.Metadata["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected refreshed model metadata, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesCodeSuggestionsSSE(t *testing.T) {
|
||||
var gotAccept, gotStreamingHeader, gotEncoding string
|
||||
var gotStreamFlag bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != gitLabCodeSuggestionsEndpoint {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
gotStreamingHeader = r.Header.Get(gitLabSSEStreamingHeader)
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotStreamFlag = gjson.GetBytes(readBody(t, r), "stream").Bool()
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: stream_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"model\":{\"name\":\"claude-sonnet-4-5\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||
_, _ = w.Write([]byte("data: {\"content\":\"hello\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||
_, _ = w.Write([]byte("data: {\"content\":\" world\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: stream_end\n"))
|
||||
_, _ = w.Write([]byte("data: {}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if gotAccept != "text/event-stream" {
|
||||
t.Fatalf("Accept = %q, want text/event-stream", gotAccept)
|
||||
}
|
||||
if gotStreamingHeader != "true" {
|
||||
t.Fatalf("%s = %q, want true", gitLabSSEStreamingHeader, gotStreamingHeader)
|
||||
}
|
||||
if gotEncoding != "identity" {
|
||||
t.Fatalf("Accept-Encoding = %q, want identity", gotEncoding)
|
||||
}
|
||||
if !gotStreamFlag {
|
||||
t.Fatalf("expected upstream request to set stream=true")
|
||||
}
|
||||
if len(lines) < 4 {
|
||||
t.Fatalf("expected translated stream chunks, got %d", len(lines))
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":"hello"`) {
|
||||
t.Fatalf("expected hello delta in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":" world"`) {
|
||||
t.Fatalf("expected world delta in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
last := lines[len(lines)-1]
|
||||
if last != "data: [DONE]" && !strings.Contains(last, `"finish_reason":"stop"`) {
|
||||
t.Fatalf("expected stream terminator, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||
chatCalls := 0
|
||||
streamCalls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case gitLabCodeSuggestionsEndpoint:
|
||||
streamCalls++
|
||||
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||
case gitLabChatEndpoint:
|
||||
chatCalls++
|
||||
_, _ = w.Write([]byte(`"chat fallback response"`))
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if streamCalls != 1 {
|
||||
t.Fatalf("expected streaming endpoint once, got %d", streamCalls)
|
||||
}
|
||||
if chatCalls != 1 {
|
||||
t.Fatalf("expected chat fallback once, got %d", chatCalls)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":"chat fallback response"`) {
|
||||
t.Fatalf("expected fallback content in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
var gotPath string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from gateway\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_stop\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"max_tokens":64}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func collectStreamLines(t *testing.T, result *cliproxyexecutor.StreamResult) []string {
|
||||
t.Helper()
|
||||
lines := make([]string, 0, 8)
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
lines = append(lines, string(chunk.Payload))
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func readBody(t *testing.T, r *http.Request) []byte {
|
||||
t.Helper()
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
return body
|
||||
}
|
||||
@@ -169,7 +169,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -281,7 +281,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -315,7 +315,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.
|
||||
|
||||
@@ -161,7 +161,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -271,12 +271,12 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
|
||||
@@ -89,6 +89,13 @@ var endpointAliases = map[string]string{
|
||||
"cli": "amazonq",
|
||||
}
|
||||
|
||||
func enqueueTranslatedSSE(out chan<- cliproxyexecutor.StreamChunk, chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: append(bytes.Clone(chunk), '\n', '\n')}
|
||||
}
|
||||
|
||||
// retryConfig holds configuration for socket retry logic.
|
||||
// Based on kiro2Api Python implementation patterns.
|
||||
type retryConfig struct {
|
||||
@@ -2573,9 +2580,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Send tool input as delta
|
||||
@@ -2583,18 +2588,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Close block
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
hasToolUses = true
|
||||
@@ -2664,9 +2665,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
messageStartSent = true
|
||||
}
|
||||
@@ -2916,9 +2915,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
lastReportedOutputTokens = currentOutputTokens
|
||||
@@ -2939,17 +2936,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
continue
|
||||
@@ -2978,18 +2971,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
// Send thinking delta
|
||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
accumulatedThinkingContent.WriteString(thinkingText)
|
||||
}
|
||||
@@ -2998,9 +2987,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isThinkingBlockOpen = false
|
||||
}
|
||||
@@ -3029,17 +3016,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
accumulatedThinkingContent.WriteString(processContent)
|
||||
}
|
||||
@@ -3058,9 +3041,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isThinkingBlockOpen = false
|
||||
}
|
||||
@@ -3071,18 +3052,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
// Send text delta
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
// Close text block before entering thinking
|
||||
@@ -3090,9 +3067,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3120,17 +3095,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3158,9 +3129,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3171,9 +3140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Send input_json_delta with the tool input
|
||||
@@ -3186,9 +3153,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3197,9 +3162,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3239,9 +3202,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3254,9 +3215,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3264,9 +3223,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Accumulate for token counting
|
||||
@@ -3298,9 +3255,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
isTextBlockOpen = false
|
||||
}
|
||||
@@ -3310,9 +3265,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
if tu.Input != nil {
|
||||
@@ -3323,9 +3276,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3333,9 +3284,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3522,9 +3471,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3609,18 +3556,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
|
||||
// Send message_stop event separately
|
||||
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
enqueueTranslatedSSE(out, chunk)
|
||||
}
|
||||
// reporter.publish is called via defer
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -205,6 +205,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request usage data in the final streaming chunk so that token statistics
|
||||
// are captured even when the upstream is an OpenAI-compatible provider.
|
||||
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
@@ -286,7 +290,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
@@ -326,7 +330,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||
}
|
||||
|
||||
// Refresh is a no-op for API-key based compatibility providers.
|
||||
|
||||
@@ -2,17 +2,15 @@ package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||
@@ -111,45 +109,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
// Returns:
|
||||
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
||||
func buildProxyTransport(proxyURL string) *http.Transport {
|
||||
if proxyURL == "" {
|
||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("%v", errBuild)
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(proxyURL)
|
||||
if errParse != nil {
|
||||
log.Errorf("parse proxy URL failed: %v", errParse)
|
||||
return nil
|
||||
}
|
||||
|
||||
var transport *http.Transport
|
||||
|
||||
// Handle different proxy schemes
|
||||
if parsedURL.Scheme == "socks5" {
|
||||
// Configure SOCKS5 proxy with optional authentication
|
||||
var proxyAuth *proxy.Auth
|
||||
if parsedURL.User != nil {
|
||||
username := parsedURL.User.Username()
|
||||
password, _ := parsedURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil
|
||||
}
|
||||
// Set up a custom transport using the SOCKS5 dialer
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
|
||||
// Configure HTTP or HTTPS proxy
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
|
||||
} else {
|
||||
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
return nil
|
||||
}
|
||||
|
||||
return transport
|
||||
}
|
||||
|
||||
30
internal/runtime/executor/proxy_helpers_test.go
Normal file
30
internal/runtime/executor/proxy_helpers_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := newProxyAwareHTTPClient(
|
||||
context.Background(),
|
||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||
0,
|
||||
)
|
||||
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("expected direct transport to disable proxy function")
|
||||
}
|
||||
}
|
||||
@@ -305,7 +305,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -421,12 +421,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
@@ -461,7 +461,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
|
||||
if suffixResult.HasSuffix {
|
||||
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
|
||||
} else {
|
||||
config = extractThinkingConfig(body, toFormat)
|
||||
config = extractThinkingConfig(body, fromFormat)
|
||||
if !hasThinkingConfig(config) && fromFormat != toFormat {
|
||||
config = extractThinkingConfig(body, toFormat)
|
||||
}
|
||||
}
|
||||
|
||||
if !hasThinkingConfig(config) {
|
||||
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
|
||||
if config.Mode != ModeLevel {
|
||||
return config
|
||||
}
|
||||
if toFormat == "claude" {
|
||||
return config
|
||||
}
|
||||
if !isBudgetCapableProvider(toFormat) {
|
||||
return config
|
||||
}
|
||||
|
||||
55
internal/thinking/apply_user_defined_test.go
Normal file
55
internal/thinking/apply_user_defined_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package thinking_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "test-user-defined-claude-" + t.Name()
|
||||
modelID := "custom-claude-4-6"
|
||||
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(clientID)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "claude adaptive effort body",
|
||||
model: modelID,
|
||||
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
|
||||
},
|
||||
{
|
||||
name: "suffix level",
|
||||
model: modelID + "(high)",
|
||||
body: []byte(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyThinking() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
|
||||
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -39,35 +40,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
var systemInstructionJSON []byte
|
||||
hasSystemInstruction := false
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
partJSON := `{}`
|
||||
partJSON := []byte(`{}`)
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
|
||||
}
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
||||
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", partJSON)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
|
||||
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
|
||||
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
|
||||
// contents
|
||||
contentsJSON := "[]"
|
||||
contentsJSON := []byte(`[]`)
|
||||
hasContents := false
|
||||
|
||||
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
|
||||
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
|
||||
toolNameByID := make(map[string]string)
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -83,8 +88,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContentJSON := `{"role":"","parts":[]}`
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
||||
clientContentJSON := []byte(`{"role":"","parts":[]}`)
|
||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "role", role)
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
@@ -143,15 +148,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||
if thinkingText != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
// Skip empty text parts to avoid Gemini API error:
|
||||
@@ -159,9 +164,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if prompt == "" {
|
||||
continue
|
||||
}
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
@@ -170,6 +175,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
|
||||
if functionID != "" && functionName != "" {
|
||||
toolNameByID[functionID] = functionName
|
||||
}
|
||||
|
||||
// Handle both object and string input formats
|
||||
var argsRaw string
|
||||
if argsResult.IsObject() {
|
||||
@@ -183,138 +192,147 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
if argsRaw != "" {
|
||||
partJSON := `{}`
|
||||
partJSON := []byte(`{}`)
|
||||
|
||||
// Use skip_thought_signature_validator for tool calls without valid thinking signature
|
||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||
// and also works for Claude through Antigravity API
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
} else {
|
||||
// No valid signature - use skip sentinel to bypass validation
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", skipSentinel)
|
||||
}
|
||||
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRawBytes(partJSON, "functionCall.args", []byte(argsRaw))
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
funcName, ok := toolNameByID[toolCallID]
|
||||
if !ok {
|
||||
// Fallback: derive a semantic name from the ID by stripping
|
||||
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
|
||||
// Only use the raw ID as a last resort when the heuristic produces an empty string.
|
||||
parts := strings.Split(toolCallID, "-")
|
||||
if len(parts) > 2 {
|
||||
funcName = strings.Join(parts[:len(parts)-2], "-")
|
||||
}
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
|
||||
}
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
functionResponseJSON := `{}`
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
||||
functionResponseJSON := []byte(`{}`)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", funcName)
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
nonImageCount := 0
|
||||
lastNonImageRaw := ""
|
||||
filteredJSON := "[]"
|
||||
imagePartsJSON := "[]"
|
||||
filteredJSON := []byte(`[]`)
|
||||
imagePartsJSON := []byte(`[]`)
|
||||
for _, fr := range frResults {
|
||||
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
inlineDataJSON := []byte(`{}`)
|
||||
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := fr.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
imagePartJSON := []byte(`{}`)
|
||||
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
|
||||
continue
|
||||
}
|
||||
|
||||
nonImageCount++
|
||||
lastNonImageRaw = fr.Raw
|
||||
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||
filteredJSON, _ = sjson.SetRawBytes(filteredJSON, "-1", []byte(fr.Raw))
|
||||
}
|
||||
|
||||
if nonImageCount == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(lastNonImageRaw))
|
||||
} else if nonImageCount > 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", filteredJSON)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
// Place image data inside functionResponse.parts as inlineData
|
||||
// instead of as sibling parts in the outer content, to avoid
|
||||
// base64 data bloating the text context.
|
||||
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
if gjson.GetBytes(imagePartsJSON, "#").Int() > 0 {
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
inlineDataJSON := []byte(`{}`)
|
||||
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
imagePartJSON := `{}`
|
||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := "[]"
|
||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
imagePartJSON := []byte(`{}`)
|
||||
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
|
||||
imagePartsJSON := []byte(`[]`)
|
||||
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
|
||||
}
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
|
||||
} else {
|
||||
// Content field is missing entirely — .Raw is empty which
|
||||
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetRawBytes(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||
sourceResult := contentResult.Get("source")
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineDataJSON := `{}`
|
||||
inlineDataJSON := []byte(`{}`)
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||
}
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
partJSON := []byte(`{}`)
|
||||
partJSON, _ = sjson.SetRawBytes(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder parts for 'model' role to ensure thinking block is first
|
||||
if role == "model" {
|
||||
partsResult := gjson.Get(clientContentJSON, "parts")
|
||||
partsResult := gjson.GetBytes(clientContentJSON, "parts")
|
||||
if partsResult.IsArray() {
|
||||
parts := partsResult.Array()
|
||||
var thinkingParts []gjson.Result
|
||||
@@ -336,7 +354,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
for _, p := range otherParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
|
||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -344,33 +362,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Skip messages with empty parts array to avoid Gemini API error:
|
||||
// "required oneof field 'data' must have one initialized field"
|
||||
partsCheck := gjson.Get(clientContentJSON, "parts")
|
||||
partsCheck := gjson.GetBytes(clientContentJSON, "parts")
|
||||
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
partJSON := `{}`
|
||||
partJSON := []byte(`{}`)
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
toolsJSON := ""
|
||||
var toolsJSON []byte
|
||||
toolDeclCount := 0
|
||||
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
||||
toolsJSON = []byte(`[{"functionDeclarations":[]}]`)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
@@ -378,23 +396,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
// Sanitize the input schema for Antigravity API compatibility
|
||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
for toolKey := range gjson.Parse(tool).Map() {
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
for toolKey := range gjson.ParseBytes(tool).Map() {
|
||||
if util.InArray(allowedToolKeys, toolKey) {
|
||||
continue
|
||||
}
|
||||
tool, _ = sjson.Delete(tool, toolKey)
|
||||
tool, _ = sjson.DeleteBytes(tool, toolKey)
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out := []byte(`{"model":"","request":{"contents":[]}}`)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
@@ -408,27 +426,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
if hasSystemInstruction {
|
||||
// Append hint as a new part to existing system instruction
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
hintPart := []byte(`{"text":""}`)
|
||||
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
|
||||
} else {
|
||||
// Create new system instruction with hint
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
|
||||
hintPart := []byte(`{"text":""}`)
|
||||
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
if hasContents {
|
||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents", contentsJSON)
|
||||
}
|
||||
if toolDeclCount > 0 {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// tool_choice
|
||||
@@ -445,15 +463,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
switch toolChoiceType {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
case "tool":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -464,8 +482,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive", "auto":
|
||||
// For adaptive thinking:
|
||||
@@ -477,31 +495,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
if effort == "max" {
|
||||
effort = "high"
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", v.Num)
|
||||
}
|
||||
|
||||
outBytes := []byte(out)
|
||||
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
|
||||
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
|
||||
|
||||
return outBytes
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "get_weather-call-123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function response conversion
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Error("functionResponse should exist")
|
||||
}
|
||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||
}
|
||||
if funcResp.Get("name").String() != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"name": "Glob",
|
||||
"input": {"pattern": "**/*.py"}
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "file1.py\nfile2.py"
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||
"content": "total 10"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp0.Exists() {
|
||||
t.Fatal("first functionResponse should exist")
|
||||
}
|
||||
if got := funcResp0.Get("name").String(); got != "Glob" {
|
||||
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
|
||||
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
|
||||
if !funcResp1.Exists() {
|
||||
t.Fatal("second functionResponse should exist")
|
||||
}
|
||||
if got := funcResp1.Get("name").String(); got != "Bash" {
|
||||
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "Read-1773420180464065165-1327",
|
||||
"name": "Read",
|
||||
"input": {"file_path": "/tmp/test.py"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "Read-1773420180464065165-1327",
|
||||
"content": "file content here"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "Read" {
|
||||
t.Errorf("Expected name 'Read', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "get_weather-call-123",
|
||||
"content": "22C sunny"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
if got := funcResp.Get("name").String(); got != "get_weather" {
|
||||
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||
"content": "result data"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Fatal("functionResponse should exist")
|
||||
}
|
||||
got := funcResp.Get("name").String()
|
||||
if got == "" {
|
||||
t.Error("functionResponse.name must not be empty")
|
||||
}
|
||||
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
|
||||
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||
@@ -1235,64 +1410,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
|
||||
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
effort string
|
||||
expected string
|
||||
}{
|
||||
{"low", "low", "low"},
|
||||
{"medium", "medium", "medium"},
|
||||
{"high", "high", "high"},
|
||||
{"max", "max", "high"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"thinking": {"type": "adaptive"},
|
||||
"output_config": {"effort": "` + tt.effort + `"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||
if !thinkingConfig.Exists() {
|
||||
t.Fatal("thinkingConfig should exist for adaptive thinking")
|
||||
}
|
||||
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
|
||||
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
|
||||
}
|
||||
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||
t.Error("includeThoughts should be true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"thinking": {"type": "adaptive"}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||
if !thinkingConfig.Exists() {
|
||||
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
|
||||
}
|
||||
if thinkingConfig.Get("thinkingLevel").String() != "high" {
|
||||
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
|
||||
}
|
||||
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||
t.Error("includeThoughts should be true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -62,8 +64,8 @@ var toolUseIDCounter uint64
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
||||
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
|
||||
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &Params{
|
||||
HasFirstResponse: false,
|
||||
@@ -76,44 +78,44 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
params := (*param).(*Params)
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
output := ""
|
||||
output := make([]byte, 0, 256)
|
||||
// Only send final events if we have actually output content
|
||||
if params.HasContent {
|
||||
appendFinalEvents(params, &output, true)
|
||||
return []string{
|
||||
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
|
||||
}
|
||||
output = translatorcommon.AppendSSEEventString(output, "message_stop", `{"type":"message_stop"}`, 3)
|
||||
return [][]byte{output}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
output := ""
|
||||
output := make([]byte, 0, 1024)
|
||||
appendEvent := func(event, payload string) {
|
||||
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
|
||||
}
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
// This is only sent for the very first response chunk to establish the streaming session
|
||||
if !params.HasFirstResponse {
|
||||
output = "event: message_start\n"
|
||||
|
||||
// Create the initial message structure with default values according to Claude Code API specification
|
||||
// This follows the Claude Code API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
messageStartTemplate := []byte(`{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`)
|
||||
|
||||
// Use cpaUsageMetadata within the message_start event for Claude.
|
||||
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||
}
|
||||
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||
}
|
||||
|
||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
}
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
}
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
||||
appendEvent("message_start", string(messageStartTemplate))
|
||||
|
||||
params.HasFirstResponse = true
|
||||
}
|
||||
@@ -143,15 +145,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex)), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.HasContent = true
|
||||
} else {
|
||||
// Transition from another state to thinking
|
||||
@@ -162,19 +162,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new thinking content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex))
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.ResponseType = 2 // Set state to thinking
|
||||
params.HasContent = true
|
||||
// Start accumulating thinking text for signature caching
|
||||
@@ -187,9 +182,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block if already in content state
|
||||
if params.ResponseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.HasContent = true
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
@@ -200,19 +194,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
}
|
||||
if partTextResult.String() != "" {
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex))
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
params.ResponseType = 1 // Set state to content
|
||||
params.HasContent = true
|
||||
}
|
||||
@@ -228,9 +217,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if params.ResponseType == 3 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
params.ResponseType = 0
|
||||
}
|
||||
@@ -244,26 +231,21 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
// Close any other existing content block
|
||||
if params.ResponseType != 0 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||
params.ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new tool use content block
|
||||
// This creates the structure for a function call in Claude Code format
|
||||
output = output + "event: content_block_start\n"
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex))
|
||||
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
|
||||
appendEvent("content_block_start", string(data))
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
|
||||
appendEvent("content_block_delta", string(data))
|
||||
}
|
||||
params.ResponseType = 3
|
||||
params.HasContent = true
|
||||
@@ -295,10 +277,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
appendFinalEvents(params, &output, false)
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
return [][]byte{output}
|
||||
}
|
||||
|
||||
func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
func appendFinalEvents(params *Params, output *[]byte, force bool) {
|
||||
if params.HasSentFinalEvents {
|
||||
return
|
||||
}
|
||||
@@ -313,9 +295,7 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
}
|
||||
|
||||
if params.ResponseType != 0 {
|
||||
*output = *output + "event: content_block_stop\n"
|
||||
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
*output = *output + "\n\n\n"
|
||||
*output = translatorcommon.AppendSSEEventString(*output, "content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex), 3)
|
||||
params.ResponseType = 0
|
||||
}
|
||||
|
||||
@@ -328,18 +308,16 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
}
|
||||
}
|
||||
|
||||
*output = *output + "event: message_delta\n"
|
||||
*output = *output + "data: "
|
||||
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
||||
delta := []byte(fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens))
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if params.CachedTokenCount > 0 {
|
||||
var err error
|
||||
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||
delta, err = sjson.SetBytes(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
*output = *output + delta + "\n\n\n"
|
||||
*output = translatorcommon.AppendSSEEventString(*output, "message_delta", string(delta), 3)
|
||||
|
||||
params.HasSentFinalEvents = true
|
||||
}
|
||||
@@ -368,8 +346,8 @@ func resolveStopReason(params *Params) string {
|
||||
// - param: A pointer to a parameter object for the conversion.
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
_ = originalRequestRawJSON
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
|
||||
@@ -387,15 +365,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
|
||||
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||
responseJSON := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.output_tokens", outputTokens)
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if cachedTokens > 0 {
|
||||
var err error
|
||||
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
responseJSON, err = sjson.SetBytes(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
@@ -406,7 +384,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
if contentArrayInitialized {
|
||||
return
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content", []byte("[]"))
|
||||
contentArrayInitialized = true
|
||||
}
|
||||
|
||||
@@ -422,9 +400,9 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
return
|
||||
}
|
||||
ensureContentArray()
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
@@ -433,12 +411,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
return
|
||||
}
|
||||
ensureContentArray()
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||
block, _ = sjson.SetBytes(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
thinkingSignature = ""
|
||||
}
|
||||
@@ -474,16 +452,16 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
toolIDCounter++
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
||||
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(args.Raw))
|
||||
}
|
||||
|
||||
ensureContentArray()
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
||||
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -507,17 +485,17 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
}
|
||||
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
||||
responseJSON, _ = sjson.SetBytes(responseJSON, "stop_reason", stopReason)
|
||||
|
||||
if promptTokens == 0 && outputTokens == 0 {
|
||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
||||
responseJSON, _ = sjson.DeleteBytes(responseJSON, "usage")
|
||||
}
|
||||
}
|
||||
|
||||
return responseJSON
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||
}
|
||||
|
||||
@@ -34,10 +34,10 @@ import (
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template := `{"project":"","request":{},"model":""}`
|
||||
templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
|
||||
templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
|
||||
template = string(templateBytes)
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||
@@ -47,7 +47,8 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
|
||||
template = string(templateBytes)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
}
|
||||
rawJSON = []byte(template)
|
||||
@@ -138,30 +139,47 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
CallNames []string // ordered function call names for backfilling empty response names
|
||||
}
|
||||
|
||||
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
|
||||
// Falls back to a minimal "functionResponse" object when parsing fails.
|
||||
func parseFunctionResponseRaw(response gjson.Result) string {
|
||||
// fallbackName is used when the response's own name is empty.
|
||||
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
|
||||
if response.IsObject() && gjson.Valid(response.Raw) {
|
||||
return response.Raw
|
||||
raw := response.Raw
|
||||
name := response.Get("functionResponse.name").String()
|
||||
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
||||
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
|
||||
raw = string(updated)
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
log.Debugf("parse function response failed, using fallback")
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr, _ = sjson.Set(fr, "functionResponse.id", id)
|
||||
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
name := funcResp.Get("name").String()
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = fallbackName
|
||||
}
|
||||
return fr
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
|
||||
}
|
||||
return string(fr)
|
||||
}
|
||||
|
||||
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
|
||||
return fr
|
||||
useName := fallbackName
|
||||
if useName == "" {
|
||||
useName = "unknown"
|
||||
}
|
||||
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
|
||||
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
|
||||
return string(fr)
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
@@ -188,7 +206,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
contentsWrapper := `{"contents":[]}`
|
||||
contentsWrapper := []byte(`{"contents":[]}`)
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
@@ -211,30 +229,26 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
if len(responsePartsInThisContent) > 0 {
|
||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||
|
||||
// Check if any pending groups can be satisfied
|
||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||
group := pendingGroups[i]
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
// Check if pending groups can be satisfied (FIFO: oldest group first)
|
||||
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
|
||||
group := pendingGroups[0]
|
||||
pendingGroups = pendingGroups[1:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response)
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
}
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||
for ri, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||
break
|
||||
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,25 +257,26 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
functionCallsCount := 0
|
||||
var callNames []string
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsCount++
|
||||
callNames = append(callNames, part.Get("functionCall.name").String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if functionCallsCount > 0 {
|
||||
if len(callNames) > 0 {
|
||||
// Add the model content
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse model content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ResponsesNeeded: functionCallsCount,
|
||||
ResponsesNeeded: len(callNames),
|
||||
CallNames: callNames,
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
@@ -270,7 +285,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
@@ -278,7 +293,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -290,23 +305,22 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response)
|
||||
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||
for ri, response := range groupResponses {
|
||||
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
|
||||
result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
|
||||
|
||||
return result, nil
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
@@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
|
||||
// When the Amp client sends functionResponse with an empty name,
|
||||
// fixCLIToolResponse should backfill it from the corresponding functionCall.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
name := funcContent.Get("parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
|
||||
// Parallel function calls: both responses have empty names.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
|
||||
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
|
||||
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
parts := funcContent.Get("parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 function response parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
name0 := parts[0].Get("functionResponse.name").String()
|
||||
name1 := parts[1].Get("functionResponse.name").String()
|
||||
if name0 != "Read" {
|
||||
t.Errorf("Expected first response name 'Read', got '%s'", name0)
|
||||
}
|
||||
if name1 != "Grep" {
|
||||
t.Errorf("Expected second response name 'Grep', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
|
||||
// When functionResponse already has a valid name, it should be preserved.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
name := funcContent.Get("parts.0.functionResponse.name").String()
|
||||
if name != "Bash" {
|
||||
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
|
||||
// If there are more function responses than calls, unmatched extras are discarded by grouping.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Bash", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
|
||||
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContent gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContent = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if !funcContent.Exists() {
|
||||
t.Fatal("function role content should exist in output")
|
||||
}
|
||||
|
||||
// First response should be backfilled from the call
|
||||
name0 := funcContent.Get("parts.0.functionResponse.name").String()
|
||||
if name0 != "Bash" {
|
||||
t.Errorf("Expected first response name 'Bash', got '%s'", name0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
|
||||
// Two sequential function call groups should be matched FIFO.
|
||||
input := `{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Read", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "file content"}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "Grep", "args": {}}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{"functionResponse": {"name": "", "response": {"result": "match"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := fixCLIToolResponse(input)
|
||||
if err != nil {
|
||||
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||
}
|
||||
|
||||
contents := gjson.Get(result, "request.contents").Array()
|
||||
var funcContents []gjson.Result
|
||||
for _, c := range contents {
|
||||
if c.Get("role").String() == "function" {
|
||||
funcContents = append(funcContents, c)
|
||||
}
|
||||
}
|
||||
if len(funcContents) != 2 {
|
||||
t.Fatalf("Expected 2 function contents, got %d", len(funcContents))
|
||||
}
|
||||
|
||||
name0 := funcContents[0].Get("parts.0.functionResponse.name").String()
|
||||
name1 := funcContents[1].Get("parts.0.functionResponse.name").String()
|
||||
if name0 != "Read" {
|
||||
t.Errorf("Expected first group name 'Read', got '%s'", name0)
|
||||
}
|
||||
if name1 != "Grep" {
|
||||
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -29,8 +29,8 @@ import (
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - []string: The transformed request data in Gemini API format
|
||||
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
|
||||
// - [][]byte: The transformed response data in Gemini API format.
|
||||
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
}
|
||||
@@ -44,22 +44,22 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
||||
chunk = restoreUsageMetadata(chunk)
|
||||
}
|
||||
} else {
|
||||
chunkTemplate := "[]"
|
||||
chunkTemplate := []byte("[]")
|
||||
responseResult := gjson.ParseBytes(chunk)
|
||||
if responseResult.IsArray() {
|
||||
responseResultItems := responseResult.Array()
|
||||
for i := 0; i < len(responseResultItems); i++ {
|
||||
responseResultItem := responseResultItems[i]
|
||||
if responseResultItem.Get("response").Exists() {
|
||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
||||
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
chunk = []byte(chunkTemplate)
|
||||
chunk = chunkTemplate
|
||||
}
|
||||
return []string{string(chunk)}
|
||||
return [][]byte{chunk}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
|
||||
@@ -73,18 +73,18 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing the response data
|
||||
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing the response data.
|
||||
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
|
||||
return string(chunk)
|
||||
return chunk
|
||||
}
|
||||
return string(rawJSON)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.
|
||||
|
||||
@@ -59,8 +59,8 @@ func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected)
|
||||
if string(result) != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", string(result), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -87,8 +87,8 @@ func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
if results[0] != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected)
|
||||
if string(results[0]) != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", string(results[0]), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -212,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else {
|
||||
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
|
||||
}
|
||||
case "input_audio":
|
||||
audioData := item.Get("input_audio.data").String()
|
||||
audioFormat := item.Get("input_audio.format").String()
|
||||
if audioData != "" {
|
||||
audioMimeMap := map[string]string{
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"webm": "audio/webm",
|
||||
"pcm16": "audio/pcm",
|
||||
"g711_ulaw": "audio/basic",
|
||||
"g711_alaw": "audio/basic",
|
||||
}
|
||||
mimeType := "audio/wav"
|
||||
if audioFormat != "" {
|
||||
if mapped, ok := audioMimeMap[audioFormat]; ok {
|
||||
mimeType = mapped
|
||||
} else {
|
||||
mimeType = "audio/" + audioFormat
|
||||
}
|
||||
}
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
|
||||
p++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,31 +354,35 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if errRename != nil {
|
||||
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
fnRaw = string(fnRawBytes)
|
||||
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
} else {
|
||||
fnRaw = renamed
|
||||
}
|
||||
} else {
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
fnRaw = string(fnRawBytes)
|
||||
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw = string(fnRawBytes)
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
|
||||
@@ -44,8 +44,8 @@ var functionCallIDCounter uint64
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
@@ -54,15 +54,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
@@ -71,14 +71,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if err == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
||||
@@ -90,21 +90,21 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
@@ -141,33 +141,33 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
|
||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||
functionCallIndex = len(toolCallsResult.Array())
|
||||
} else {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data == "" {
|
||||
@@ -181,16 +181,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
|
||||
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
|
||||
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -212,11 +212,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
} else {
|
||||
finishReason = "stop"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
|
||||
@@ -231,11 +231,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
|
||||
}
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||
if len(result1) != 1 {
|
||||
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
||||
}
|
||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
||||
}
|
||||
@@ -33,13 +33,13 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||
if len(result2) != 1 {
|
||||
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
||||
}
|
||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr2 != "tool_calls" {
|
||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
||||
}
|
||||
|
||||
// Verify native_finish_reason is lowercase upstream value
|
||||
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
|
||||
nfr2 := gjson.GetBytes(result2[0], "choices.0.native_finish_reason").String()
|
||||
if nfr2 != "stop" {
|
||||
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func TestFinishReasonStopForNormalText(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "stop" (no tool calls were made)
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "stop" {
|
||||
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func TestFinishReasonMaxTokens(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "max_tokens"
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "max_tokens" {
|
||||
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "tool_calls" {
|
||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Verify no finish_reason on intermediate chunk
|
||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify no finish_reason on intermediate chunk
|
||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
|
||||
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason")
|
||||
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
rawJSON = []byte(responseResult.Raw)
|
||||
@@ -15,7 +15,7 @@ func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName
|
||||
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
|
||||
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
rawJSON = []byte(responseResult.Raw)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"context"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||
"github.com/tidwall/sjson"
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
)
|
||||
|
||||
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
|
||||
@@ -23,15 +23,13 @@ import (
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
// Wrap each converted response in a "response" object to match Gemini CLI API structure
|
||||
newOutputs := make([]string, 0)
|
||||
newOutputs := make([][]byte, 0, len(outputs))
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
json := `{"response": {}}`
|
||||
output, _ := sjson.SetRaw(json, "response", outputs[i])
|
||||
newOutputs = append(newOutputs, output)
|
||||
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
|
||||
}
|
||||
return newOutputs
|
||||
}
|
||||
@@ -47,15 +45,13 @@ func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, ori
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
strJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
// - []byte: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
// Wrap the converted response in a "response" object to match Gemini CLI API structure
|
||||
json := `{"response": {}}`
|
||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
||||
return strJSON
|
||||
return translatorcommon.WrapGeminiCLIResponse(out)
|
||||
}
|
||||
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
|
||||
return GeminiTokenCount(ctx, count)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
|
||||
|
||||
// Base Claude message payload
|
||||
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
|
||||
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -87,20 +87,20 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
var pendingToolIDs []string
|
||||
|
||||
// Model mapping to specify which Claude Code model to use
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Generation config extraction from Gemini format
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
// Max output tokens configuration
|
||||
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := genConfig.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
|
||||
} else if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
|
||||
}
|
||||
// Stop sequences configuration for custom termination conditions
|
||||
if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() {
|
||||
@@ -110,7 +110,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
return true
|
||||
})
|
||||
if len(stopSequences) > 0 {
|
||||
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
|
||||
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
|
||||
}
|
||||
}
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
@@ -132,30 +132,30 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
|
||||
level = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", level)
|
||||
}
|
||||
} else {
|
||||
switch level {
|
||||
case "":
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
default:
|
||||
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,37 +169,37 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if supportsAdaptive {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
level, ok := thinking.ConvertBudgetToLevel(budget)
|
||||
if ok {
|
||||
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
|
||||
level = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", level)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,9 +220,9 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
})
|
||||
if systemText.Len() > 0 {
|
||||
// Create system message in Claude Code format
|
||||
systemMessage := `{"role":"user","content":[{"type":"text","text":""}]}`
|
||||
systemMessage, _ = sjson.Set(systemMessage, "content.0.text", systemText.String())
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", systemMessage)
|
||||
systemMessage := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
|
||||
systemMessage, _ = sjson.SetBytes(systemMessage, "content.0.text", systemText.String())
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,42 +245,42 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Create message structure in Claude Code format
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
|
||||
if parts := content.Get("parts"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// Text content conversion
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", text.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
|
||||
textContent := []byte(`{"type":"text","text":""}`)
|
||||
textContent, _ = sjson.SetBytes(textContent, "text", text.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
|
||||
return true
|
||||
}
|
||||
|
||||
// Function call (from model/assistant) conversion to tool use
|
||||
if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" {
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
|
||||
// Generate a unique tool ID and enqueue it for later matching
|
||||
// with the corresponding functionResponse
|
||||
toolID := genToolCallID()
|
||||
pendingToolIDs = append(pendingToolIDs, toolID)
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolID)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "id", toolID)
|
||||
|
||||
if name := fc.Get("name"); name.Exists() {
|
||||
toolUse, _ = sjson.Set(toolUse, "name", name.String())
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "name", name.String())
|
||||
}
|
||||
if args := fc.Get("args"); args.Exists() && args.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(args.Raw))
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
|
||||
return true
|
||||
}
|
||||
|
||||
// Function response (from user) conversion to tool result
|
||||
if fr := part.Get("functionResponse"); fr.Exists() {
|
||||
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
|
||||
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
|
||||
|
||||
// Attach the oldest queued tool_id to pair the response
|
||||
// with its call. If the queue is empty, generate a new id.
|
||||
@@ -293,41 +293,41 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Fallback: generate new ID if no pending tool_use found
|
||||
toolID = genToolCallID()
|
||||
}
|
||||
toolResult, _ = sjson.Set(toolResult, "tool_use_id", toolID)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", toolID)
|
||||
|
||||
// Extract result content from the function response
|
||||
if result := fr.Get("response.result"); result.Exists() {
|
||||
toolResult, _ = sjson.Set(toolResult, "content", result.String())
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "content", result.String())
|
||||
} else if response := fr.Get("response"); response.Exists() {
|
||||
toolResult, _ = sjson.Set(toolResult, "content", response.Raw)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "content", response.Raw)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolResult)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolResult)
|
||||
return true
|
||||
}
|
||||
|
||||
// Image content (inline_data) conversion to Claude Code format
|
||||
if inlineData := part.Get("inline_data"); inlineData.Exists() {
|
||||
imageContent := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imageContent := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
if mimeType := inlineData.Get("mime_type"); mimeType.Exists() {
|
||||
imageContent, _ = sjson.Set(imageContent, "source.media_type", mimeType.String())
|
||||
imageContent, _ = sjson.SetBytes(imageContent, "source.media_type", mimeType.String())
|
||||
}
|
||||
if data := inlineData.Get("data"); data.Exists() {
|
||||
imageContent, _ = sjson.Set(imageContent, "source.data", data.String())
|
||||
imageContent, _ = sjson.SetBytes(imageContent, "source.data", data.String())
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imageContent)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", imageContent)
|
||||
return true
|
||||
}
|
||||
|
||||
// File data conversion to text content with file info
|
||||
if fileData := part.Get("file_data"); fileData.Exists() {
|
||||
// For file data, we'll convert to text content with file info
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent := []byte(`{"type":"text","text":""}`)
|
||||
fileInfo := "File: " + fileData.Get("file_uri").String()
|
||||
if mimeType := fileData.Get("mime_type"); mimeType.Exists() {
|
||||
fileInfo += " (Type: " + mimeType.String() + ")"
|
||||
}
|
||||
textContent, _ = sjson.Set(textContent, "text", fileInfo)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
|
||||
textContent, _ = sjson.SetBytes(textContent, "text", fileInfo)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -336,8 +336,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Only add message if it has content
|
||||
if contentArray := gjson.Get(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
if contentArray := gjson.GetBytes(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -351,29 +351,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() {
|
||||
funcDecls.ForEach(func(_, funcDecl gjson.Result) bool {
|
||||
anthropicTool := `{"name":"","description":"","input_schema":{}}`
|
||||
anthropicTool := []byte(`{"name":"","description":"","input_schema":{}}`)
|
||||
|
||||
if name := funcDecl.Get("name"); name.Exists() {
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "name", name.String())
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", name.String())
|
||||
}
|
||||
if desc := funcDecl.Get("description"); desc.Exists() {
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "description", desc.String())
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", desc.String())
|
||||
}
|
||||
if params := funcDecl.Get("parameters"); params.Exists() {
|
||||
// Clean up the parameters schema for Claude Code compatibility
|
||||
cleaned := params.Raw
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
|
||||
cleaned := []byte(params.Raw)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
|
||||
} else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() {
|
||||
// Clean up the parameters schema for Claude Code compatibility
|
||||
cleaned := params.Raw
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
|
||||
cleaned := []byte(params.Raw)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
|
||||
}
|
||||
|
||||
anthropicTools = append(anthropicTools, gjson.Parse(anthropicTool).Value())
|
||||
anthropicTools = append(anthropicTools, gjson.ParseBytes(anthropicTool).Value())
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -381,7 +381,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
})
|
||||
|
||||
if len(anthropicTools) > 0 {
|
||||
out, _ = sjson.Set(out, "tools", anthropicTools)
|
||||
out, _ = sjson.SetBytes(out, "tools", anthropicTools)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,27 +391,27 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if mode := funcCalling.Get("mode"); mode.Exists() {
|
||||
switch mode.String() {
|
||||
case "AUTO":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
|
||||
case "NONE":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"none"}`))
|
||||
case "ANY":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream setting configuration
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// Convert tool parameter types to lowercase for Claude Code compatibility
|
||||
var pathsToLower []string
|
||||
toolsResult := gjson.Get(out, "tools")
|
||||
toolsResult := gjson.GetBytes(out, "tools")
|
||||
util.Walk(toolsResult, "", "type", &pathsToLower)
|
||||
for _, p := range pathsToLower {
|
||||
fullPath := fmt.Sprintf("tools.%s", p)
|
||||
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
|
||||
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -30,7 +30,7 @@ type ConvertAnthropicResponseToGeminiParams struct {
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ResponseID string
|
||||
LastStorageOutput string
|
||||
LastStorageOutput []byte
|
||||
IsStreaming bool
|
||||
|
||||
// Streaming state for tool_use assembly
|
||||
@@ -52,8 +52,8 @@ type ConvertAnthropicResponseToGeminiParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
|
||||
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses
|
||||
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertAnthropicResponseToGeminiParams{
|
||||
Model: modelName,
|
||||
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
@@ -71,24 +71,24 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
eventType := root.Get("type").String()
|
||||
|
||||
// Base Gemini response template with default values
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
|
||||
|
||||
// Set model version
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" {
|
||||
// Map Claude model names back to Gemini model names
|
||||
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
|
||||
}
|
||||
|
||||
// Set response ID and creation time
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" {
|
||||
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
|
||||
}
|
||||
|
||||
// Set creation time to current time if not provided
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 {
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
@@ -97,7 +97,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String()
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String()
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "content_block_start":
|
||||
// Start of a content block - record tool_use name by index for functionCall assembly
|
||||
@@ -112,7 +112,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "content_block_delta":
|
||||
// Handle content delta (text, thinking, or tool use arguments)
|
||||
@@ -123,16 +123,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
case "text_delta":
|
||||
// Regular text content delta for normal response text
|
||||
if text := delta.Get("text"); text.Exists() && text.String() != "" {
|
||||
textPart := `{"text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", textPart)
|
||||
textPart := []byte(`{"text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", text.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", textPart)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Thinking/reasoning content delta for models with reasoning capabilities
|
||||
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
|
||||
thinkingPart := `{"thought":true,"text":""}`
|
||||
thinkingPart, _ = sjson.Set(thinkingPart, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", thinkingPart)
|
||||
thinkingPart := []byte(`{"thought":true,"text":""}`)
|
||||
thinkingPart, _ = sjson.SetBytes(thinkingPart, "text", text.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", thinkingPart)
|
||||
}
|
||||
case "input_json_delta":
|
||||
// Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop
|
||||
@@ -149,10 +149,10 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
if pj := delta.Get("partial_json"); pj.Exists() {
|
||||
b.WriteString(pj.String())
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
|
||||
case "content_block_stop":
|
||||
// End of content block - finalize tool calls if any
|
||||
@@ -170,16 +170,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
if name != "" || argsTrim != "" {
|
||||
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
|
||||
if name != "" {
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
|
||||
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", name)
|
||||
}
|
||||
if argsTrim != "" {
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsTrim)
|
||||
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsTrim))
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = template
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
|
||||
// cleanup used state for this index
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
|
||||
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx)
|
||||
@@ -187,9 +187,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
|
||||
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx)
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "message_delta":
|
||||
// Handle message-level changes (like stop reason and usage information)
|
||||
@@ -197,15 +197,15 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
|
||||
switch stopReason.String() {
|
||||
case "end_turn":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
case "tool_use":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
case "max_tokens":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "MAX_TOKENS")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "MAX_TOKENS")
|
||||
case "stop_sequence":
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
default:
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,35 +216,35 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
|
||||
// Set basic usage metadata according to Gemini API specification
|
||||
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
|
||||
|
||||
// Add cache-related token counts if present (Claude Code API cache fields)
|
||||
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
|
||||
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
}
|
||||
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
|
||||
// Add cache read tokens to cached content count
|
||||
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
|
||||
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
|
||||
}
|
||||
|
||||
// Add thinking tokens if present (for models with reasoning capabilities)
|
||||
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
|
||||
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
|
||||
}
|
||||
|
||||
// Set traffic type (required by Gemini API)
|
||||
template, _ = sjson.Set(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
|
||||
}
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
case "message_stop":
|
||||
// Final message with usage information - no additional output needed
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
case "error":
|
||||
// Handle error responses and convert to Gemini error format
|
||||
errorMsg := root.Get("error.message").String()
|
||||
@@ -253,13 +253,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
// Create error response in Gemini format
|
||||
errorResponse := `{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`
|
||||
errorResponse, _ = sjson.Set(errorResponse, "error.message", errorMsg)
|
||||
return []string{errorResponse}
|
||||
errorResponse := []byte(`{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`)
|
||||
errorResponse, _ = sjson.SetBytes(errorResponse, "error.message", errorMsg)
|
||||
return [][]byte{errorResponse}
|
||||
|
||||
default:
|
||||
// Unknown event type, return empty response
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,13 +275,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
// Base Gemini response template for non-streaming with default values
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
|
||||
|
||||
// Set model version
|
||||
template, _ = sjson.Set(template, "modelVersion", modelName)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
|
||||
|
||||
streamingEvents := make([][]byte, 0)
|
||||
|
||||
@@ -304,15 +304,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
LastStorageOutput: "",
|
||||
LastStorageOutput: nil,
|
||||
IsStreaming: false,
|
||||
ToolUseNames: nil,
|
||||
ToolUseArgs: nil,
|
||||
}
|
||||
|
||||
// Process each streaming event and collect parts
|
||||
var allParts []string
|
||||
var finalUsageJSON string
|
||||
var allParts [][]byte
|
||||
var finalUsageJSON []byte
|
||||
var responseID string
|
||||
var createdAt int64
|
||||
|
||||
@@ -360,15 +360,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "text_delta":
|
||||
// Process regular text content
|
||||
if text := delta.Get("text"); text.Exists() && text.String() != "" {
|
||||
partJSON := `{"text":""}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", text.String())
|
||||
partJSON := []byte(`{"text":""}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
|
||||
allParts = append(allParts, partJSON)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Process reasoning/thinking content
|
||||
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
|
||||
partJSON := `{"thought":true,"text":""}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", text.String())
|
||||
partJSON := []byte(`{"thought":true,"text":""}`)
|
||||
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
|
||||
allParts = append(allParts, partJSON)
|
||||
}
|
||||
case "input_json_delta":
|
||||
@@ -402,12 +402,12 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
}
|
||||
}
|
||||
if name != "" || argsTrim != "" {
|
||||
functionCallJSON := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCallJSON := []byte(`{"functionCall":{"name":"","args":{}}}`)
|
||||
if name != "" {
|
||||
functionCallJSON, _ = sjson.Set(functionCallJSON, "functionCall.name", name)
|
||||
functionCallJSON, _ = sjson.SetBytes(functionCallJSON, "functionCall.name", name)
|
||||
}
|
||||
if argsTrim != "" {
|
||||
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
|
||||
functionCallJSON, _ = sjson.SetRawBytes(functionCallJSON, "functionCall.args", []byte(argsTrim))
|
||||
}
|
||||
allParts = append(allParts, functionCallJSON)
|
||||
// cleanup used state for this index
|
||||
@@ -422,35 +422,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "message_delta":
|
||||
// Extract final usage information using sjson for token counts and metadata
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
usageJSON := `{}`
|
||||
usageJSON := []byte(`{}`)
|
||||
|
||||
// Basic token counts for prompt and completion
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
|
||||
// Set basic usage metadata according to Gemini API specification
|
||||
usageJSON, _ = sjson.Set(usageJSON, "promptTokenCount", inputTokens)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "candidatesTokenCount", outputTokens)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "totalTokenCount", inputTokens+outputTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "promptTokenCount", inputTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "candidatesTokenCount", outputTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "totalTokenCount", inputTokens+outputTokens)
|
||||
|
||||
// Add cache-related token counts if present (Claude Code API cache fields)
|
||||
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
|
||||
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
|
||||
}
|
||||
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
|
||||
// Add cache read tokens to cached content count
|
||||
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
|
||||
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", totalCacheTokens)
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", totalCacheTokens)
|
||||
}
|
||||
|
||||
// Add thinking tokens if present (for models with reasoning capabilities)
|
||||
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
|
||||
usageJSON, _ = sjson.Set(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
|
||||
}
|
||||
|
||||
// Set traffic type (required by Gemini API)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
|
||||
usageJSON, _ = sjson.SetBytes(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
|
||||
|
||||
finalUsageJSON = usageJSON
|
||||
}
|
||||
@@ -459,10 +459,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set response metadata
|
||||
if responseID != "" {
|
||||
template, _ = sjson.Set(template, "responseId", responseID)
|
||||
template, _ = sjson.SetBytes(template, "responseId", responseID)
|
||||
}
|
||||
if createdAt > 0 {
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
// Consolidate consecutive text parts and thinking parts for cleaner output
|
||||
@@ -470,35 +470,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set the consolidated parts array
|
||||
if len(consolidatedParts) > 0 {
|
||||
partsJSON := "[]"
|
||||
partsJSON := []byte(`[]`)
|
||||
for _, partJSON := range consolidatedParts {
|
||||
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
|
||||
partsJSON, _ = sjson.SetRawBytes(partsJSON, "-1", partJSON)
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts", partsJSON)
|
||||
}
|
||||
|
||||
// Set usage metadata
|
||||
if finalUsageJSON != "" {
|
||||
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
|
||||
if len(finalUsageJSON) > 0 {
|
||||
template, _ = sjson.SetRawBytes(template, "usageMetadata", finalUsageJSON)
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
// consolidateParts merges consecutive text parts and thinking parts to create a cleaner response.
|
||||
// This function processes the parts array to combine adjacent text elements and thinking elements
|
||||
// into single consolidated parts, which results in a more readable and efficient response structure.
|
||||
// Tool calls and other non-text parts are preserved as separate elements.
|
||||
func consolidateParts(parts []string) []string {
|
||||
func consolidateParts(parts [][]byte) [][]byte {
|
||||
if len(parts) == 0 {
|
||||
return parts
|
||||
}
|
||||
|
||||
var consolidated []string
|
||||
var consolidated [][]byte
|
||||
var currentTextPart strings.Builder
|
||||
var currentThoughtPart strings.Builder
|
||||
var hasText, hasThought bool
|
||||
@@ -506,8 +506,8 @@ func consolidateParts(parts []string) []string {
|
||||
flushText := func() {
|
||||
// Flush accumulated text content to the consolidated parts array
|
||||
if hasText && currentTextPart.Len() > 0 {
|
||||
textPartJSON := `{"text":""}`
|
||||
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
|
||||
textPartJSON := []byte(`{"text":""}`)
|
||||
textPartJSON, _ = sjson.SetBytes(textPartJSON, "text", currentTextPart.String())
|
||||
consolidated = append(consolidated, textPartJSON)
|
||||
currentTextPart.Reset()
|
||||
hasText = false
|
||||
@@ -517,8 +517,8 @@ func consolidateParts(parts []string) []string {
|
||||
flushThought := func() {
|
||||
// Flush accumulated thinking content to the consolidated parts array
|
||||
if hasThought && currentThoughtPart.Len() > 0 {
|
||||
thoughtPartJSON := `{"thought":true,"text":""}`
|
||||
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
|
||||
thoughtPartJSON := []byte(`{"thought":true,"text":""}`)
|
||||
thoughtPartJSON, _ = sjson.SetBytes(thoughtPartJSON, "text", currentThoughtPart.String())
|
||||
consolidated = append(consolidated, thoughtPartJSON)
|
||||
currentThoughtPart.Reset()
|
||||
hasThought = false
|
||||
@@ -526,7 +526,7 @@ func consolidateParts(parts []string) []string {
|
||||
}
|
||||
|
||||
for _, partJSON := range parts {
|
||||
part := gjson.Parse(partJSON)
|
||||
part := gjson.ParseBytes(partJSON)
|
||||
if !part.Exists() || !part.IsObject() {
|
||||
// Flush any pending parts and add this non-text part
|
||||
flushText()
|
||||
|
||||
@@ -61,7 +61,7 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
|
||||
|
||||
// Base Claude Code API template with default max_tokens value
|
||||
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
|
||||
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -79,20 +79,20 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if supportsAdaptive {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||
effort = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
|
||||
}
|
||||
} else {
|
||||
// Legacy/manual thinking (budget_tokens).
|
||||
@@ -100,13 +100,13 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,19 +128,19 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Model mapping to specify which Claude Code model to use
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Max tokens configuration with fallback to default value
|
||||
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := root.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
|
||||
} else if topP := root.Get("top_p"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
// Stop sequences configuration for custom termination conditions
|
||||
@@ -152,15 +152,15 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
return true
|
||||
})
|
||||
if len(stopSequences) > 0 {
|
||||
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
|
||||
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_sequences", []string{stop.String()})
|
||||
out, _ = sjson.SetBytes(out, "stop_sequences", []string{stop.String()})
|
||||
}
|
||||
}
|
||||
|
||||
// Stream configuration to enable or disable streaming responses
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// Process messages and transform them to Claude Code format
|
||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||
@@ -173,76 +173,39 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
switch role {
|
||||
case "system":
|
||||
if systemMessageIndex == -1 {
|
||||
systemMsg := `{"role":"user","content":[]}`
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", systemMsg)
|
||||
systemMsg := []byte(`{"role":"user","content":[]}`)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMsg)
|
||||
systemMessageIndex = messageIndex
|
||||
messageIndex++
|
||||
}
|
||||
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", contentResult.String())
|
||||
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
|
||||
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
|
||||
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
case "user", "assistant":
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
|
||||
// Handle content based on its type (string or array)
|
||||
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
|
||||
part := `{"type":"text","text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentResult.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{"type":"text","text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", contentResult.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
|
||||
|
||||
case "image_url":
|
||||
// Convert OpenAI image format to Claude Code format
|
||||
imageURL := part.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
// Extract base64 data and media type from data URL
|
||||
parts := strings.Split(imageURL, ",")
|
||||
if len(parts) == 2 {
|
||||
mediaTypePart := strings.Split(parts[0], ";")[0]
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
data := parts[1]
|
||||
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(claudePart))
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -258,9 +221,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
function := toolCall.Get("function")
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
|
||||
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "id", toolCallID)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "name", function.Get("name").String())
|
||||
|
||||
// Parse arguments for the tool call
|
||||
if args := function.Get("arguments"); args.Exists() {
|
||||
@@ -268,35 +231,40 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
|
||||
} else {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
|
||||
}
|
||||
} else {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
|
||||
}
|
||||
} else {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
|
||||
}
|
||||
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
messageIndex++
|
||||
|
||||
case "tool":
|
||||
// Handle tool result messages conversion
|
||||
toolCallID := message.Get("tool_call_id").String()
|
||||
content := message.Get("content").String()
|
||||
toolContentResult := message.Get("content")
|
||||
|
||||
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
||||
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
||||
msg, _ = sjson.Set(msg, "content.0.content", content)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
msg := []byte(`{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "content.0.tool_use_id", toolCallID)
|
||||
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
|
||||
if toolResultContentRaw {
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.0.content", []byte(toolResultContent))
|
||||
} else {
|
||||
msg, _ = sjson.SetBytes(msg, "content.0.content", toolResultContent)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
messageIndex++
|
||||
}
|
||||
return true
|
||||
@@ -309,25 +277,25 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if tool.Get("type").String() == "function" {
|
||||
function := tool.Get("function")
|
||||
anthropicTool := `{"name":"","description":""}`
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
|
||||
anthropicTool := []byte(`{"name":"","description":""}`)
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", function.Get("name").String())
|
||||
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", function.Get("description").String())
|
||||
|
||||
// Convert parameters schema for the tool
|
||||
if parameters := function.Get("parameters"); parameters.Exists() {
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
|
||||
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
|
||||
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", anthropicTool)
|
||||
hasAnthropicTools = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasAnthropicTools {
|
||||
out, _ = sjson.Delete(out, "tools")
|
||||
out, _ = sjson.DeleteBytes(out, "tools")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,21 +308,128 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
case "none":
|
||||
// Don't set tool_choice, Claude Code will not use tools
|
||||
case "auto":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
|
||||
case "required":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
case gjson.JSON:
|
||||
// Specific tool choice mapping
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
functionName := toolChoice.Get("function.name").String()
|
||||
toolChoiceJSON := `{"type":"tool","name":""}`
|
||||
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
|
||||
toolChoiceJSON := []byte(`{"type":"tool","name":""}`)
|
||||
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", functionName)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
|
||||
switch part.Get("type").String() {
|
||||
case "text":
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
|
||||
return string(textPart)
|
||||
|
||||
case "image_url":
|
||||
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
docPart, _ = sjson.SetBytes(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.SetBytes(docPart, "source.data", data)
|
||||
return string(docPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertOpenAIImageURLToClaudePart(imageURL string) string {
|
||||
if imageURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.HasPrefix(imageURL, "data:") {
|
||||
parts := strings.SplitN(imageURL, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
if mediaType == "" {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
|
||||
imagePart := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
imagePart, _ = sjson.SetBytes(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.SetBytes(imagePart, "source.data", parts[1])
|
||||
return string(imagePart)
|
||||
}
|
||||
|
||||
imagePart := []byte(`{"type":"image","source":{"type":"url","url":""}}`)
|
||||
imagePart, _ = sjson.SetBytes(imagePart, "source.url", imageURL)
|
||||
return string(imagePart)
|
||||
}
|
||||
|
||||
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
|
||||
if !content.Exists() {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String(), false
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
claudeContent := []byte("[]")
|
||||
partCount := 0
|
||||
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Type == gjson.String {
|
||||
textPart := []byte(`{"type":"text","text":""}`)
|
||||
textPart, _ = sjson.SetBytes(textPart, "text", part.String())
|
||||
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", textPart)
|
||||
partCount++
|
||||
return true
|
||||
}
|
||||
|
||||
claudePart := convertOpenAIContentPartToClaudePart(part)
|
||||
if claudePart != "" {
|
||||
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
|
||||
partCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if partCount > 0 || len(content.Array()) == 0 {
|
||||
return string(claudeContent), true
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
claudePart := convertOpenAIContentPartToClaudePart(content)
|
||||
if claudePart != "" {
|
||||
claudeContent := []byte("[]")
|
||||
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
|
||||
return string(claudeContent), true
|
||||
}
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
return content.Raw, false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_work",
|
||||
"arguments": "{\"a\":1}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": [
|
||||
{"type": "text", "text": "tool ok"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolResult := messages[1].Get("content.0")
|
||||
if got := toolResult.Get("type").String(); got != "tool_result" {
|
||||
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
|
||||
}
|
||||
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
|
||||
}
|
||||
|
||||
toolContent := toolResult.Get("content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "text" {
|
||||
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
|
||||
}
|
||||
if got := toolContent.Get("0.text").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
|
||||
}
|
||||
if got := toolContent.Get("1.type").String(); got != "image" {
|
||||
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.type").String(); got != "base64" {
|
||||
t.Fatalf("Expected image source type %q, got %q", "base64", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
|
||||
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
|
||||
}
|
||||
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
|
||||
t.Fatalf("Unexpected base64 image data: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "gpt-4.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_work",
|
||||
"arguments": "{\"a\":1}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/tool.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
toolContent := messages[1].Get("content.0.content")
|
||||
if !toolContent.IsArray() {
|
||||
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
|
||||
}
|
||||
if got := toolContent.Get("0.type").String(); got != "image" {
|
||||
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
|
||||
}
|
||||
if got := toolContent.Get("0.source.type").String(); got != "url" {
|
||||
t.Fatalf("Expected image source type %q, got %q", "url", got)
|
||||
}
|
||||
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
|
||||
t.Fatalf("Unexpected image URL: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -48,8 +48,8 @@ type ToolCallAccumulator struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
var localParam any
|
||||
if param == nil {
|
||||
param = &localParam
|
||||
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
@@ -71,19 +71,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
eventType := root.Get("type").String()
|
||||
|
||||
// Base OpenAI streaming response template
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`)
|
||||
|
||||
// Set model
|
||||
if modelName != "" {
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
}
|
||||
|
||||
// Set response ID and creation time
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID != "" {
|
||||
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
}
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt > 0 {
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
@@ -93,19 +93,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID = message.Get("id").String()
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt = time.Now().Unix()
|
||||
|
||||
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
|
||||
|
||||
// Set initial role to assistant for the response
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
|
||||
// Initialize tool calls accumulator for tracking tool call progress
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||
}
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
|
||||
case "content_block_start":
|
||||
// Start of a content block (text, tool use, or reasoning)
|
||||
@@ -128,10 +128,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
}
|
||||
|
||||
// Don't output anything yet - wait for complete tool call
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "content_block_delta":
|
||||
// Handle content delta (text, tool use arguments, or reasoning content)
|
||||
@@ -143,13 +143,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
case "text_delta":
|
||||
// Text content delta - send incremental text updates
|
||||
if text := delta.Get("text"); text.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", text.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", text.String())
|
||||
hasContent = true
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Accumulate reasoning/thinking content
|
||||
if thinking := delta.Get("thinking"); thinking.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", thinking.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", thinking.String())
|
||||
hasContent = true
|
||||
}
|
||||
case "input_json_delta":
|
||||
@@ -163,13 +163,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
// Don't output anything yet - wait for complete tool call
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
if hasContent {
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
} else {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
@@ -182,26 +182,26 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
if arguments == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.index", index)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.type", "function")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
|
||||
|
||||
// Clean up the accumulator for this index
|
||||
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "message_delta":
|
||||
// Handle message-level changes including stop reason and usage
|
||||
if delta := root.Get("delta"); delta.Exists() {
|
||||
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String())
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,34 +211,34 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
}
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
|
||||
case "message_stop":
|
||||
// Final message event - no additional output needed
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "ping":
|
||||
// Ping events for keeping connection alive - no output needed
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
case "error":
|
||||
// Error event - format and return error response
|
||||
if errorData := root.Get("error"); errorData.Exists() {
|
||||
errorJSON := `{"error":{"message":"","type":""}}`
|
||||
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
|
||||
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
|
||||
return []string{errorJSON}
|
||||
errorJSON := []byte(`{"error":{"message":"","type":""}}`)
|
||||
errorJSON, _ = sjson.SetBytes(errorJSON, "error.message", errorData.Get("message").String())
|
||||
errorJSON, _ = sjson.SetBytes(errorJSON, "error.type", errorData.Get("type").String())
|
||||
return [][]byte{errorJSON}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
|
||||
default:
|
||||
// Unknown event type - ignore
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,8 +270,8 @@ func mapAnthropicStopReasonToOpenAI(anthropicReason string) string {
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
chunks := make([][]byte, 0)
|
||||
|
||||
lines := bytes.Split(rawJSON, []byte("\n"))
|
||||
@@ -283,7 +283,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
|
||||
// Base OpenAI non-streaming response template
|
||||
out := `{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
out := []byte(`{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
|
||||
|
||||
var messageID string
|
||||
var model string
|
||||
@@ -370,28 +370,28 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set basic response fields including message ID, creation time, and model
|
||||
out, _ = sjson.Set(out, "id", messageID)
|
||||
out, _ = sjson.Set(out, "created", createdAt)
|
||||
out, _ = sjson.Set(out, "model", model)
|
||||
out, _ = sjson.SetBytes(out, "id", messageID)
|
||||
out, _ = sjson.SetBytes(out, "created", createdAt)
|
||||
out, _ = sjson.SetBytes(out, "model", model)
|
||||
|
||||
// Set message content by combining all text parts
|
||||
messageContent := strings.Join(contentParts, "")
|
||||
out, _ = sjson.Set(out, "choices.0.message.content", messageContent)
|
||||
out, _ = sjson.SetBytes(out, "choices.0.message.content", messageContent)
|
||||
|
||||
// Add reasoning content if available (following OpenAI reasoning format)
|
||||
if len(reasoningParts) > 0 {
|
||||
reasoningContent := strings.Join(reasoningParts, "")
|
||||
// Add reasoning as a separate field in the message
|
||||
out, _ = sjson.Set(out, "choices.0.message.reasoning", reasoningContent)
|
||||
out, _ = sjson.SetBytes(out, "choices.0.message.reasoning", reasoningContent)
|
||||
}
|
||||
|
||||
// Set tool calls if any were accumulated during processing
|
||||
@@ -417,19 +417,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
|
||||
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
|
||||
|
||||
out, _ = sjson.Set(out, idPath, accumulator.ID)
|
||||
out, _ = sjson.Set(out, typePath, "function")
|
||||
out, _ = sjson.Set(out, namePath, accumulator.Name)
|
||||
out, _ = sjson.Set(out, argumentsPath, arguments)
|
||||
out, _ = sjson.SetBytes(out, idPath, accumulator.ID)
|
||||
out, _ = sjson.SetBytes(out, typePath, "function")
|
||||
out, _ = sjson.SetBytes(out, namePath, accumulator.Name)
|
||||
out, _ = sjson.SetBytes(out, argumentsPath, arguments)
|
||||
toolCallsCount++
|
||||
}
|
||||
if toolCallsCount > 0 {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
|
||||
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", "tool_calls")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -49,7 +49,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
|
||||
|
||||
// Base Claude message payload
|
||||
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
|
||||
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -67,20 +67,20 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if supportsAdaptive {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Delete(out, "output_config.effort")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.DeleteBytes(out, "output_config.effort")
|
||||
default:
|
||||
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||
effort = mapped
|
||||
}
|
||||
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
|
||||
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
|
||||
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
|
||||
}
|
||||
} else {
|
||||
// Legacy/manual thinking (budget_tokens).
|
||||
@@ -88,13 +88,13 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,15 +114,15 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
// Model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Max tokens
|
||||
if mot := root.Get("max_output_tokens"); mot.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", mot.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", mot.Int())
|
||||
}
|
||||
|
||||
// Stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// instructions -> as a leading message (use role user for Claude API compatibility)
|
||||
instructionsText := ""
|
||||
@@ -130,9 +130,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String {
|
||||
instructionsText = instr.String()
|
||||
if instructionsText != "" {
|
||||
sysMsg := `{"role":"user","content":""}`
|
||||
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
|
||||
sysMsg := []byte(`{"role":"user","content":""}`)
|
||||
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,9 +156,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
instructionsText = builder.String()
|
||||
if instructionsText != "" {
|
||||
sysMsg := `{"role":"user","content":""}`
|
||||
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
|
||||
sysMsg := []byte(`{"role":"user","content":""}`)
|
||||
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
|
||||
extractedFromSystem = true
|
||||
}
|
||||
}
|
||||
@@ -193,9 +193,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if t := part.Get("text"); t.Exists() {
|
||||
txt := t.String()
|
||||
textAggregate.WriteString(txt)
|
||||
contentPart := `{"type":"text","text":""}`
|
||||
contentPart, _ = sjson.Set(contentPart, "text", txt)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
contentPart := []byte(`{"type":"text","text":""}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "text", txt)
|
||||
partsJSON = append(partsJSON, string(contentPart))
|
||||
}
|
||||
if ptype == "input_text" {
|
||||
role = "user"
|
||||
@@ -208,7 +208,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
url = part.Get("url").String()
|
||||
}
|
||||
if url != "" {
|
||||
var contentPart string
|
||||
var contentPart []byte
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
trimmed := strings.TrimPrefix(url, "data:")
|
||||
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||
@@ -221,16 +221,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
if data != "" {
|
||||
contentPart = `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
contentPart = []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
|
||||
}
|
||||
} else {
|
||||
contentPart = `{"type":"image","source":{"type":"url","url":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.url", url)
|
||||
contentPart = []byte(`{"type":"image","source":{"type":"url","url":""}}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.url", url)
|
||||
}
|
||||
if contentPart != "" {
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
if len(contentPart) > 0 {
|
||||
partsJSON = append(partsJSON, string(contentPart))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
@@ -252,10 +252,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
}
|
||||
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
contentPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, string(contentPart))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
@@ -280,24 +280,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
if len(partsJSON) > 0 {
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||
// Preserve legacy behavior for single text content
|
||||
msg, _ = sjson.Delete(msg, "content")
|
||||
msg, _ = sjson.DeleteBytes(msg, "content")
|
||||
textPart := gjson.Parse(partsJSON[0])
|
||||
msg, _ = sjson.Set(msg, "content", textPart.Get("text").String())
|
||||
msg, _ = sjson.SetBytes(msg, "content", textPart.Get("text").String())
|
||||
} else {
|
||||
for _, partJSON := range partsJSON {
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", partJSON)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(partJSON))
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
} else if textAggregate.Len() > 0 || role == "system" {
|
||||
msg := `{"role":"","content":""}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg, _ = sjson.Set(msg, "content", textAggregate.String())
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
msg := []byte(`{"role":"","content":""}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
msg, _ = sjson.SetBytes(msg, "content", textAggregate.String())
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
|
||||
}
|
||||
|
||||
case "function_call":
|
||||
@@ -309,31 +309,31 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
name := item.Get("name").String()
|
||||
argsStr := item.Get("arguments").String()
|
||||
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", callID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", name)
|
||||
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "id", callID)
|
||||
toolUse, _ = sjson.SetBytes(toolUse, "name", name)
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
|
||||
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
|
||||
}
|
||||
}
|
||||
|
||||
asst := `{"role":"assistant","content":[]}`
|
||||
asst, _ = sjson.SetRaw(asst, "content.-1", toolUse)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", asst)
|
||||
asst := []byte(`{"role":"assistant","content":[]}`)
|
||||
asst, _ = sjson.SetRawBytes(asst, "content.-1", toolUse)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
|
||||
|
||||
case "function_call_output":
|
||||
// Map to user tool_result
|
||||
callID := item.Get("call_id").String()
|
||||
outputStr := item.Get("output").String()
|
||||
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
|
||||
toolResult, _ = sjson.Set(toolResult, "tool_use_id", callID)
|
||||
toolResult, _ = sjson.Set(toolResult, "content", outputStr)
|
||||
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", callID)
|
||||
toolResult, _ = sjson.SetBytes(toolResult, "content", outputStr)
|
||||
|
||||
usr := `{"role":"user","content":[]}`
|
||||
usr, _ = sjson.SetRaw(usr, "content.-1", toolResult)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", usr)
|
||||
usr := []byte(`{"role":"user","content":[]}`)
|
||||
usr, _ = sjson.SetRawBytes(usr, "content.-1", toolResult)
|
||||
out, _ = sjson.SetRawBytes(out, "messages.-1", usr)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -341,27 +341,27 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
// tools mapping: parameters -> input_schema
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
toolsJSON := "[]"
|
||||
toolsJSON := []byte("[]")
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
tJSON := `{"name":"","description":"","input_schema":{}}`
|
||||
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
|
||||
if n := tool.Get("name"); n.Exists() {
|
||||
tJSON, _ = sjson.Set(tJSON, "name", n.String())
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
|
||||
}
|
||||
if d := tool.Get("description"); d.Exists() {
|
||||
tJSON, _ = sjson.Set(tJSON, "description", d.String())
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
|
||||
}
|
||||
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
|
||||
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
|
||||
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
|
||||
}
|
||||
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", tJSON)
|
||||
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
|
||||
return true
|
||||
})
|
||||
if gjson.Parse(toolsJSON).IsArray() && len(gjson.Parse(toolsJSON).Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "tools", toolsJSON)
|
||||
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "tools", toolsJSON)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -371,23 +371,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
case gjson.String:
|
||||
switch toolChoice.String() {
|
||||
case "auto":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
|
||||
case "none":
|
||||
// Leave unset; implies no tools
|
||||
case "required":
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
case gjson.JSON:
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
fn := toolChoice.Get("function.name").String()
|
||||
toolChoiceJSON := `{"name":"","type":"tool"}`
|
||||
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
|
||||
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
|
||||
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -50,12 +51,12 @@ func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitEvent(event string, payload string) string {
|
||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||
func emitEvent(event string, payload []byte) []byte {
|
||||
return translatorcommon.SSEEventData(event, payload)
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
|
||||
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)}
|
||||
}
|
||||
@@ -63,12 +64,12 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
|
||||
// Expect `data: {..}` from Claude clients
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
ev := root.Get("type").String()
|
||||
var out []string
|
||||
var out [][]byte
|
||||
|
||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||
|
||||
@@ -105,16 +106,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
}
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.SetBytes(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.SetBytes(created, "response.created_at", st.CreatedAt)
|
||||
out = append(out, emitEvent("response.created", created))
|
||||
// response.in_progress
|
||||
inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
|
||||
inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
|
||||
inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
|
||||
inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt)
|
||||
inprog := []byte(`{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`)
|
||||
inprog, _ = sjson.SetBytes(inprog, "sequence_number", nextSeq())
|
||||
inprog, _ = sjson.SetBytes(inprog, "response.id", st.ResponseID)
|
||||
inprog, _ = sjson.SetBytes(inprog, "response.created_at", st.CreatedAt)
|
||||
out = append(out, emitEvent("response.in_progress", inprog))
|
||||
}
|
||||
case "content_block_start":
|
||||
@@ -128,25 +129,25 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
// open message item + content part
|
||||
st.InTextBlock = true
|
||||
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "item.id", st.CurrentMsgID)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "item.id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
|
||||
part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
part, _ = sjson.Set(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.Set(part, "item_id", st.CurrentMsgID)
|
||||
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.SetBytes(part, "item_id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.content_part.added", part))
|
||||
} else if typ == "tool_use" {
|
||||
st.InFuncBlock = true
|
||||
st.CurrentFCID = cb.Get("id").String()
|
||||
name := cb.Get("name").String()
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "output_index", idx)
|
||||
item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
item, _ = sjson.Set(item, "item.call_id", st.CurrentFCID)
|
||||
item, _ = sjson.Set(item, "item.name", name)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
item, _ = sjson.SetBytes(item, "item.call_id", st.CurrentFCID)
|
||||
item, _ = sjson.SetBytes(item, "item.name", name)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
if st.FuncArgsBuf[idx] == nil {
|
||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||
@@ -160,16 +161,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
st.ReasoningIndex = idx
|
||||
st.ReasoningBuf.Reset()
|
||||
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "output_index", idx)
|
||||
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
||||
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningItemID)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
// add a summary part placeholder
|
||||
part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||
part, _ = sjson.Set(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.Set(part, "item_id", st.ReasoningItemID)
|
||||
part, _ = sjson.Set(part, "output_index", idx)
|
||||
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||
part, _ = sjson.SetBytes(part, "item_id", st.ReasoningItemID)
|
||||
part, _ = sjson.SetBytes(part, "output_index", idx)
|
||||
out = append(out, emitEvent("response.reasoning_summary_part.added", part))
|
||||
st.ReasoningPartAdded = true
|
||||
}
|
||||
@@ -181,10 +182,10 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
dt := d.Get("type").String()
|
||||
if dt == "text_delta" {
|
||||
if t := d.Get("text"); t.Exists() {
|
||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", st.CurrentMsgID)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.output_text.delta", msg))
|
||||
// aggregate text for response.output
|
||||
st.TextBuf.WriteString(t.String())
|
||||
@@ -196,22 +197,22 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||
}
|
||||
st.FuncArgsBuf[idx].WriteString(pj.String())
|
||||
msg := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
msg, _ = sjson.Set(msg, "output_index", idx)
|
||||
msg, _ = sjson.Set(msg, "delta", pj.String())
|
||||
msg := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", idx)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", pj.String())
|
||||
out = append(out, emitEvent("response.function_call_arguments.delta", msg))
|
||||
}
|
||||
} else if dt == "thinking_delta" {
|
||||
if st.ReasoningActive {
|
||||
if t := d.Get("thinking"); t.Exists() {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
msg := []byte(`{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`)
|
||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.SetBytes(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.SetBytes(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.SetBytes(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
}
|
||||
@@ -219,17 +220,17 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
case "content_block_stop":
|
||||
idx := int(root.Get("index").Int())
|
||||
if st.InTextBlock {
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
|
||||
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
st.InTextBlock = false
|
||||
} else if st.InFuncBlock {
|
||||
@@ -239,34 +240,34 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
args = buf.String()
|
||||
}
|
||||
}
|
||||
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
|
||||
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
|
||||
fcDone, _ = sjson.Set(fcDone, "arguments", args)
|
||||
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", idx)
|
||||
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
|
||||
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
|
||||
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
|
||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", st.CurrentFCID)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
st.InFuncBlock = false
|
||||
} else if st.ReasoningActive {
|
||||
full := st.ReasoningBuf.String()
|
||||
textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
|
||||
textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID)
|
||||
textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
|
||||
textDone, _ = sjson.Set(textDone, "text", full)
|
||||
textDone := []byte(`{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`)
|
||||
textDone, _ = sjson.SetBytes(textDone, "sequence_number", nextSeq())
|
||||
textDone, _ = sjson.SetBytes(textDone, "item_id", st.ReasoningItemID)
|
||||
textDone, _ = sjson.SetBytes(textDone, "output_index", st.ReasoningIndex)
|
||||
textDone, _ = sjson.SetBytes(textDone, "text", full)
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
|
||||
partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", full)
|
||||
partDone := []byte(`{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", st.ReasoningItemID)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", st.ReasoningIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", full)
|
||||
out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
|
||||
st.ReasoningActive = false
|
||||
st.ReasoningPartAdded = false
|
||||
@@ -284,92 +285,92 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
case "message_stop":
|
||||
|
||||
completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
|
||||
completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
|
||||
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
|
||||
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.SetBytes(completed, "response.created_at", st.CreatedAt)
|
||||
// Inject original request fields into response as per docs/response.completed.json
|
||||
|
||||
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
|
||||
if len(reqBytes) > 0 {
|
||||
req := gjson.ParseBytes(reqBytes)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
|
||||
}
|
||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
|
||||
}
|
||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
|
||||
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
|
||||
}
|
||||
if v := req.Get("model"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.model", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
|
||||
}
|
||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
|
||||
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
|
||||
}
|
||||
if v := req.Get("previous_response_id"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
|
||||
}
|
||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
|
||||
}
|
||||
if v := req.Get("reasoning"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
|
||||
}
|
||||
if v := req.Get("safety_identifier"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
|
||||
}
|
||||
if v := req.Get("service_tier"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.service_tier", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
|
||||
}
|
||||
if v := req.Get("store"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.store", v.Bool())
|
||||
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
|
||||
}
|
||||
if v := req.Get("temperature"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.temperature", v.Float())
|
||||
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
|
||||
}
|
||||
if v := req.Get("text"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.text", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
|
||||
}
|
||||
if v := req.Get("tool_choice"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
|
||||
}
|
||||
if v := req.Get("tools"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.tools", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
|
||||
}
|
||||
if v := req.Get("top_logprobs"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
|
||||
}
|
||||
if v := req.Get("top_p"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.top_p", v.Float())
|
||||
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
|
||||
}
|
||||
if v := req.Get("truncation"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.truncation", v.String())
|
||||
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
|
||||
}
|
||||
if v := req.Get("user"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.user", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
|
||||
}
|
||||
if v := req.Get("metadata"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.metadata", v.Value())
|
||||
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
// Build response.output from aggregated state
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
// reasoning item (if any)
|
||||
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
// assistant message item (if any text)
|
||||
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
// function_call items (in ascending index order for determinism)
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
@@ -396,16 +397,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if callID == "" && st.CurrentFCID != "" {
|
||||
callID = st.CurrentFCID
|
||||
}
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", callID)
|
||||
item, _ = sjson.Set(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
}
|
||||
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
|
||||
reasoningTokens := int64(0)
|
||||
@@ -414,15 +415,15 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
usagePresent := st.UsageSeen || reasoningTokens > 0
|
||||
if usagePresent {
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.InputTokens)
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.OutputTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.InputTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.OutputTokens)
|
||||
if reasoningTokens > 0 {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
total := st.InputTokens + st.OutputTokens
|
||||
if total > 0 || st.UsageSeen {
|
||||
completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
|
||||
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
|
||||
}
|
||||
}
|
||||
out = append(out, emitEvent("response.completed", completed))
|
||||
@@ -432,7 +433,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToOpenAIResponsesNonStream aggregates Claude SSE into a single OpenAI Responses JSON.
|
||||
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
// Aggregate Claude SSE lines into a single OpenAI Responses JSON (non-stream)
|
||||
// We follow the same aggregation logic as the streaming variant but produce
|
||||
// one final object matching docs/out.json structure.
|
||||
@@ -455,7 +456,7 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
}
|
||||
|
||||
// Base OpenAI Responses (non-stream) object
|
||||
out := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`
|
||||
out := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`)
|
||||
|
||||
// Aggregation state
|
||||
var (
|
||||
@@ -557,88 +558,88 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
}
|
||||
|
||||
// Populate base fields
|
||||
out, _ = sjson.Set(out, "id", responseID)
|
||||
out, _ = sjson.Set(out, "created_at", createdAt)
|
||||
out, _ = sjson.SetBytes(out, "id", responseID)
|
||||
out, _ = sjson.SetBytes(out, "created_at", createdAt)
|
||||
|
||||
// Inject request echo fields as top-level (similar to streaming variant)
|
||||
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
|
||||
if len(reqBytes) > 0 {
|
||||
req := gjson.ParseBytes(reqBytes)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "instructions", v.String())
|
||||
out, _ = sjson.SetBytes(out, "instructions", v.String())
|
||||
}
|
||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "max_output_tokens", v.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_output_tokens", v.Int())
|
||||
}
|
||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tool_calls", v.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tool_calls", v.Int())
|
||||
}
|
||||
if v := req.Get("model"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "model", v.String())
|
||||
out, _ = sjson.SetBytes(out, "model", v.String())
|
||||
}
|
||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", v.Bool())
|
||||
out, _ = sjson.SetBytes(out, "parallel_tool_calls", v.Bool())
|
||||
}
|
||||
if v := req.Get("previous_response_id"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "previous_response_id", v.String())
|
||||
out, _ = sjson.SetBytes(out, "previous_response_id", v.String())
|
||||
}
|
||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "prompt_cache_key", v.String())
|
||||
out, _ = sjson.SetBytes(out, "prompt_cache_key", v.String())
|
||||
}
|
||||
if v := req.Get("reasoning"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "reasoning", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "reasoning", v.Value())
|
||||
}
|
||||
if v := req.Get("safety_identifier"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "safety_identifier", v.String())
|
||||
out, _ = sjson.SetBytes(out, "safety_identifier", v.String())
|
||||
}
|
||||
if v := req.Get("service_tier"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "service_tier", v.String())
|
||||
out, _ = sjson.SetBytes(out, "service_tier", v.String())
|
||||
}
|
||||
if v := req.Get("store"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "store", v.Bool())
|
||||
out, _ = sjson.SetBytes(out, "store", v.Bool())
|
||||
}
|
||||
if v := req.Get("temperature"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", v.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", v.Float())
|
||||
}
|
||||
if v := req.Get("text"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text", v.Value())
|
||||
}
|
||||
if v := req.Get("tool_choice"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "tool_choice", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "tool_choice", v.Value())
|
||||
}
|
||||
if v := req.Get("tools"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "tools", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "tools", v.Value())
|
||||
}
|
||||
if v := req.Get("top_logprobs"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "top_logprobs", v.Int())
|
||||
out, _ = sjson.SetBytes(out, "top_logprobs", v.Int())
|
||||
}
|
||||
if v := req.Get("top_p"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", v.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", v.Float())
|
||||
}
|
||||
if v := req.Get("truncation"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "truncation", v.String())
|
||||
out, _ = sjson.SetBytes(out, "truncation", v.String())
|
||||
}
|
||||
if v := req.Get("user"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "user", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "user", v.Value())
|
||||
}
|
||||
if v := req.Get("metadata"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "metadata", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "metadata", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
// Build output array
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
outputsWrapper := []byte(`{"arr":[]}`)
|
||||
if reasoningBuf.Len() > 0 {
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", reasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||
item, _ = sjson.SetBytes(item, "id", reasoningItemID)
|
||||
item, _ = sjson.SetBytes(item, "summary.0.text", reasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if currentMsgID != "" || textBuf.Len() > 0 {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", currentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", textBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||
item, _ = sjson.SetBytes(item, "id", currentMsgID)
|
||||
item, _ = sjson.SetBytes(item, "content.0.text", textBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
// Preserve index order
|
||||
@@ -659,28 +660,28 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", st.id)
|
||||
item, _ = sjson.Set(item, "name", st.name)
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", st.id))
|
||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||
item, _ = sjson.SetBytes(item, "call_id", st.id)
|
||||
item, _ = sjson.SetBytes(item, "name", st.name)
|
||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
}
|
||||
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
|
||||
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
|
||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||
}
|
||||
|
||||
// Usage
|
||||
total := inputTokens + outputTokens
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", total)
|
||||
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", total)
|
||||
if reasoningBuf.Len() > 0 {
|
||||
// Rough estimate similar to chat completions
|
||||
reasoningTokens := int64(len(reasoningBuf.String()) / 4)
|
||||
if reasoningTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,32 +36,41 @@ import (
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
template := []byte(`{"model":"","instructions":"","input":[]}`)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
|
||||
// Process system messages and convert them to input content format.
|
||||
systemsResult := rootResult.Get("system")
|
||||
if systemsResult.IsArray() {
|
||||
systemResults := systemsResult.Array()
|
||||
message := `{"type":"message","role":"developer","content":[]}`
|
||||
if systemsResult.Exists() {
|
||||
message := []byte(`{"type":"message","role":"developer","content":[]}`)
|
||||
contentIndex := 0
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemResult := systemResults[i]
|
||||
systemTypeResult := systemResult.Get("type")
|
||||
if systemTypeResult.String() == "text" {
|
||||
text := systemResult.Get("text").String()
|
||||
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
continue
|
||||
|
||||
appendSystemText := func(text string) {
|
||||
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
return
|
||||
}
|
||||
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
|
||||
if systemsResult.Type == gjson.String {
|
||||
appendSystemText(systemsResult.String())
|
||||
} else if systemsResult.IsArray() {
|
||||
systemResults := systemsResult.Array()
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemResult := systemResults[i]
|
||||
if systemResult.Get("type").String() == "text" {
|
||||
appendSystemText(systemResult.Get("text").String())
|
||||
}
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if contentIndex > 0 {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", message)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,9 +83,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
messageResult := messageResults[i]
|
||||
messageRole := messageResult.Get("role").String()
|
||||
|
||||
newMessage := func() string {
|
||||
msg := `{"type": "message","role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", messageRole)
|
||||
newMessage := func() []byte {
|
||||
msg := []byte(`{"type":"message","role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", messageRole)
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -86,7 +95,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
flushMessage := func() {
|
||||
if hasContent {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", message)
|
||||
message = newMessage()
|
||||
contentIndex = 0
|
||||
hasContent = false
|
||||
@@ -98,15 +107,15 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if messageRole == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
hasContent = true
|
||||
}
|
||||
|
||||
appendImageContent := func(dataURL string) {
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
|
||||
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
|
||||
contentIndex++
|
||||
hasContent = true
|
||||
}
|
||||
@@ -142,8 +151,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
case "tool_use":
|
||||
flushMessage()
|
||||
functionCallMessage := `{"type":"function_call"}`
|
||||
functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
|
||||
functionCallMessage := []byte(`{"type":"function_call"}`)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
|
||||
{
|
||||
name := messageContentResult.Get("name").String()
|
||||
toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON)
|
||||
@@ -152,19 +161,19 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
functionCallMessage, _ = sjson.Set(functionCallMessage, "name", name)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "name", name)
|
||||
}
|
||||
functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
|
||||
template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallMessage)
|
||||
case "tool_result":
|
||||
flushMessage()
|
||||
functionCallOutputMessage := `{"type":"function_call_output"}`
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||
functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||
|
||||
contentResult := messageContentResult.Get("content")
|
||||
if contentResult.IsArray() {
|
||||
toolResultContentIndex := 0
|
||||
toolResultContent := `[]`
|
||||
toolResultContent := []byte(`[]`)
|
||||
contentResults := contentResult.Array()
|
||||
for k := 0; k < len(contentResults); k++ {
|
||||
toolResultContentType := contentResults[k].Get("type").String()
|
||||
@@ -185,27 +194,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
|
||||
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
|
||||
toolResultContentIndex++
|
||||
}
|
||||
}
|
||||
} else if toolResultContentType == "text" {
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
|
||||
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
|
||||
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
|
||||
toolResultContentIndex++
|
||||
}
|
||||
}
|
||||
if toolResultContent != `[]` {
|
||||
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
|
||||
if toolResultContentIndex > 0 {
|
||||
functionCallOutputMessage, _ = sjson.SetRawBytes(functionCallOutputMessage, "output", toolResultContent)
|
||||
} else {
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
}
|
||||
} else {
|
||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||
}
|
||||
|
||||
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
|
||||
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallOutputMessage)
|
||||
}
|
||||
}
|
||||
flushMessage()
|
||||
@@ -220,8 +229,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Convert tools declarations to the expected format for the Codex API.
|
||||
toolsResult := rootResult.Get("tools")
|
||||
if toolsResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "tools", `[]`)
|
||||
template, _ = sjson.Set(template, "tool_choice", `auto`)
|
||||
template, _ = sjson.SetRawBytes(template, "tools", []byte(`[]`))
|
||||
template, _ = sjson.SetBytes(template, "tool_choice", `auto`)
|
||||
toolResults := toolsResult.Array()
|
||||
// Build short name map from declared tools
|
||||
var names []string
|
||||
@@ -237,11 +246,11 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Special handling: map Claude web search tool to Codex web_search
|
||||
if toolResult.Get("type").String() == "web_search_20250305" {
|
||||
// Replace the tool content entirely with {"type":"web_search"}
|
||||
template, _ = sjson.SetRaw(template, "tools.-1", `{"type":"web_search"}`)
|
||||
template, _ = sjson.SetRawBytes(template, "tools.-1", []byte(`{"type":"web_search"}`))
|
||||
continue
|
||||
}
|
||||
tool := toolResult.Raw
|
||||
tool, _ = sjson.Set(tool, "type", "function")
|
||||
tool := []byte(toolResult.Raw)
|
||||
tool, _ = sjson.SetBytes(tool, "type", "function")
|
||||
// Apply shortened name if needed
|
||||
if v := toolResult.Get("name"); v.Exists() {
|
||||
name := v.String()
|
||||
@@ -250,20 +259,26 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
tool, _ = sjson.Set(tool, "name", name)
|
||||
tool, _ = sjson.SetBytes(tool, "name", name)
|
||||
}
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
|
||||
tool, _ = sjson.Delete(tool, "input_schema")
|
||||
tool, _ = sjson.Delete(tool, "parameters.$schema")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
tool, _ = sjson.Set(tool, "strict", false)
|
||||
template, _ = sjson.SetRaw(template, "tools.-1", tool)
|
||||
tool, _ = sjson.SetRawBytes(tool, "parameters", []byte(normalizeToolParameters(toolResult.Get("input_schema").Raw)))
|
||||
tool, _ = sjson.DeleteBytes(tool, "input_schema")
|
||||
tool, _ = sjson.DeleteBytes(tool, "parameters.$schema")
|
||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||
tool, _ = sjson.SetBytes(tool, "strict", false)
|
||||
template, _ = sjson.SetRawBytes(template, "tools.-1", tool)
|
||||
}
|
||||
}
|
||||
|
||||
// Default to parallel tool calls unless tool_choice explicitly disables them.
|
||||
parallelToolCalls := true
|
||||
if disableParallelToolUse := rootResult.Get("tool_choice.disable_parallel_tool_use"); disableParallelToolUse.Exists() {
|
||||
parallelToolCalls = !disableParallelToolUse.Bool()
|
||||
}
|
||||
|
||||
// Add additional configuration parameters for the Codex API.
|
||||
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||
template, _ = sjson.SetBytes(template, "parallel_tool_calls", parallelToolCalls)
|
||||
|
||||
// Convert thinking.budget_tokens to reasoning.effort.
|
||||
reasoningEffort := "medium"
|
||||
@@ -294,13 +309,13 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.Set(template, "stream", true)
|
||||
template, _ = sjson.Set(template, "store", false)
|
||||
template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})
|
||||
template, _ = sjson.SetBytes(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.SetBytes(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.SetBytes(template, "stream", true)
|
||||
template, _ = sjson.SetBytes(template, "store", false)
|
||||
template, _ = sjson.SetBytes(template, "include", []string{"reasoning.encrypted_content"})
|
||||
|
||||
return []byte(template)
|
||||
return template
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies a simple shortening rule for a single name.
|
||||
@@ -403,15 +418,15 @@ func normalizeToolParameters(raw string) string {
|
||||
if raw == "" || raw == "null" || !gjson.Valid(raw) {
|
||||
return `{"type":"object","properties":{}}`
|
||||
}
|
||||
schema := raw
|
||||
result := gjson.Parse(raw)
|
||||
schema := []byte(raw)
|
||||
schemaType := result.Get("type").String()
|
||||
if schemaType == "" {
|
||||
schema, _ = sjson.Set(schema, "type", "object")
|
||||
schema, _ = sjson.SetBytes(schema, "type", "object")
|
||||
schemaType = "object"
|
||||
}
|
||||
if schemaType == "object" && !result.Get("properties").Exists() {
|
||||
schema, _ = sjson.SetRaw(schema, "properties", `{}`)
|
||||
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
|
||||
}
|
||||
return schema
|
||||
return string(schema)
|
||||
}
|
||||
|
||||
135
internal/translator/codex/claude/codex_claude_request_test.go
Normal file
135
internal/translator/codex/claude/codex_claude_request_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantHasDeveloper bool
|
||||
wantTexts []string
|
||||
}{
|
||||
{
|
||||
name: "No system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": "",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: false,
|
||||
},
|
||||
{
|
||||
name: "String system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": "Be helpful",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: true,
|
||||
wantTexts: []string{"Be helpful"},
|
||||
},
|
||||
{
|
||||
name: "Array system field with filtered billing header",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": [
|
||||
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
|
||||
{"type": "text", "text": "Block 1"},
|
||||
{"type": "text", "text": "Block 2"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasDeveloper: true,
|
||||
wantTexts: []string{"Block 1", "Block 2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
inputs := resultJSON.Get("input").Array()
|
||||
|
||||
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
|
||||
if hasDeveloper != tt.wantHasDeveloper {
|
||||
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
|
||||
}
|
||||
|
||||
if !tt.wantHasDeveloper {
|
||||
return
|
||||
}
|
||||
|
||||
content := inputs[0].Get("content").Array()
|
||||
if len(content) != len(tt.wantTexts) {
|
||||
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
|
||||
}
|
||||
|
||||
for i, wantText := range tt.wantTexts {
|
||||
if gotType := content[i].Get("type").String(); gotType != "input_text" {
|
||||
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
|
||||
}
|
||||
if gotText := content[i].Get("text").String(); gotText != wantText {
|
||||
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantParallelToolCalls bool
|
||||
}{
|
||||
{
|
||||
name: "Default to true when tool_choice.disable_parallel_tool_use is absent",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantParallelToolCalls: true,
|
||||
},
|
||||
{
|
||||
name: "Disable parallel tool calls when client opts out",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"tool_choice": {"disable_parallel_tool_use": true},
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantParallelToolCalls: false,
|
||||
},
|
||||
{
|
||||
name: "Keep parallel tool calls enabled when client explicitly allows them",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"tool_choice": {"disable_parallel_tool_use": false},
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantParallelToolCalls: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
if got := resultJSON.Get("parallel_tool_calls").Bool(); got != tt.wantParallelToolCalls {
|
||||
t.Fatalf("parallel_tool_calls = %v, want %v. Output: %s", got, tt.wantParallelToolCalls, string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,9 +9,10 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -42,8 +43,8 @@ type ConvertCodexResponseToClaudeParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Claude Code-compatible JSON responses
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCodexResponseToClaudeParams{
|
||||
HasToolCall: false,
|
||||
@@ -53,95 +54,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
|
||||
// log.Debugf("rawJSON: %s", string(rawJSON))
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
output := ""
|
||||
output := make([]byte, 0, 512)
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
typeResult := rootResult.Get("type")
|
||||
typeStr := typeResult.String()
|
||||
template := ""
|
||||
var template []byte
|
||||
if typeStr == "response.created" {
|
||||
template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
|
||||
template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())
|
||||
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
|
||||
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.SetBytes(template, "message.id", rootResult.Get("response.id").String())
|
||||
|
||||
output = "event: message_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_part.added" {
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_text.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
|
||||
|
||||
output = "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.reasoning_summary_part.done" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
|
||||
} else if typeStr == "response.content_part.added" {
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
} else if typeStr == "response.output_text.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
|
||||
|
||||
output = "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.content_part.done" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
} else if typeStr == "response.completed" {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
stopReason := rootResult.Get("response.stop_reason").String()
|
||||
if p {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
|
||||
} else if stopReason == "max_tokens" || stopReason == "stop" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", stopReason)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", "end_turn")
|
||||
}
|
||||
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.input_tokens", inputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
output = "event: message_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output += "event: message_stop\n"
|
||||
output += `data: {"type":"message_stop"}`
|
||||
output += "\n\n"
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_delta", template, 2)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "message_stop", []byte(`{"type":"message_stop"}`), 2)
|
||||
} else if typeStr == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -149,37 +140,33 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
template, _ = sjson.Set(template, "content_block.name", name)
|
||||
template, _ = sjson.SetBytes(template, "content_block.name", name)
|
||||
}
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
}
|
||||
} else if typeStr == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||
}
|
||||
} else if typeStr == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
} else if typeStr == "response.function_call_arguments.done" {
|
||||
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||
// in a single "done" event without preceding "delta" events.
|
||||
@@ -188,17 +175,16 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
// When delta events were already received, skip to avoid duplicating arguments.
|
||||
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||
if args := rootResult.Get("arguments").String(); args != "" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", args)
|
||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
return [][]byte{output}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response.
|
||||
@@ -213,28 +199,28 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Claude Code-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
|
||||
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
responseData := rootResult.Get("response")
|
||||
if !responseData.Exists() {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
||||
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
out, _ = sjson.SetBytes(out, "id", responseData.Get("id").String())
|
||||
out, _ = sjson.SetBytes(out, "model", responseData.Get("model").String())
|
||||
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
hasToolCall := false
|
||||
@@ -275,9 +261,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
}
|
||||
if thinkingBuilder.Len() > 0 {
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
case "message":
|
||||
if content := item.Get("content"); content.Exists() {
|
||||
@@ -286,9 +272,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
if part.Get("type").String() == "output_text" {
|
||||
text := part.Get("text").String()
|
||||
if text != "" {
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", text)
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -296,9 +282,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
} else {
|
||||
text := content.String()
|
||||
if text != "" {
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", text)
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -309,9 +295,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
name = original
|
||||
}
|
||||
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
@@ -319,23 +305,23 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
inputRaw = argsJSON.Raw
|
||||
}
|
||||
}
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
|
||||
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
|
||||
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", stopReason.String())
|
||||
} else if hasToolCall {
|
||||
out, _ = sjson.Set(out, "stop_reason", "tool_use")
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", "tool_use")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", "end_turn")
|
||||
}
|
||||
|
||||
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
|
||||
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "stop_sequence", []byte(stopSequence.Raw))
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -385,6 +371,6 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
|
||||
return rev
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||
}
|
||||
|
||||
@@ -6,10 +6,9 @@ package geminiCLI
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/sjson"
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
)
|
||||
|
||||
// ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format.
|
||||
@@ -24,14 +23,12 @@ import (
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
newOutputs := make([]string, 0)
|
||||
newOutputs := make([][]byte, 0, len(outputs))
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
json := `{"response": {}}`
|
||||
output, _ := sjson.SetRaw(json, "response", outputs[i])
|
||||
newOutputs = append(newOutputs, output)
|
||||
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
|
||||
}
|
||||
return newOutputs
|
||||
}
|
||||
@@ -47,15 +44,12 @@ func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, orig
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
// log.Debug(string(rawJSON))
|
||||
strJSON := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
json := `{"response": {}}`
|
||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
||||
return strJSON
|
||||
// - []byte: A Gemini-compatible JSON response wrapped in a response object
|
||||
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
out := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
return translatorcommon.WrapGeminiCLIResponse(out)
|
||||
}
|
||||
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ import (
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
out := []byte(`{"model":"","instructions":"","input":[]}`)
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -82,24 +82,24 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
// Model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// System instruction -> as a user message with input_text parts
|
||||
sysParts := root.Get("system_instruction.parts")
|
||||
if sysParts.IsArray() {
|
||||
msg := `{"type":"message","role":"developer","content":[]}`
|
||||
msg := []byte(`{"type":"message","role":"developer","content":[]}`)
|
||||
arr := sysParts.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
p := arr[i]
|
||||
if t := p.Get("text"); t.Exists() {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_text")
|
||||
part, _ = sjson.Set(part, "text", t.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_text")
|
||||
part, _ = sjson.SetBytes(part, "text", t.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
if len(gjson.Get(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
if len(gjson.GetBytes(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,23 +123,23 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
p := parr[j]
|
||||
// text part
|
||||
if t := p.Get("text"); t.Exists() {
|
||||
msg := `{"type":"message","role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg := []byte(`{"type":"message","role":"","content":[]}`)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
partType := "input_text"
|
||||
if role == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", partType)
|
||||
part, _ = sjson.Set(part, "text", t.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", partType)
|
||||
part, _ = sjson.SetBytes(part, "text", t.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
|
||||
continue
|
||||
}
|
||||
|
||||
// function call from model
|
||||
if fc := p.Get("functionCall"); fc.Exists() {
|
||||
fn := `{"type":"function_call"}`
|
||||
fn := []byte(`{"type":"function_call"}`)
|
||||
if name := fc.Get("name"); name.Exists() {
|
||||
n := name.String()
|
||||
if short, ok := shortMap[n]; ok {
|
||||
@@ -147,31 +147,31 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
n = shortenNameIfNeeded(n)
|
||||
}
|
||||
fn, _ = sjson.Set(fn, "name", n)
|
||||
fn, _ = sjson.SetBytes(fn, "name", n)
|
||||
}
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
fn, _ = sjson.Set(fn, "arguments", args.Raw)
|
||||
fn, _ = sjson.SetBytes(fn, "arguments", args.Raw)
|
||||
}
|
||||
// generate a paired random call_id and enqueue it so the
|
||||
// corresponding functionResponse can pop the earliest id
|
||||
// to preserve ordering when multiple calls are present.
|
||||
id := genCallID()
|
||||
fn, _ = sjson.Set(fn, "call_id", id)
|
||||
fn, _ = sjson.SetBytes(fn, "call_id", id)
|
||||
pendingCallIDs = append(pendingCallIDs, id)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", fn)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", fn)
|
||||
continue
|
||||
}
|
||||
|
||||
// function response from user
|
||||
if fr := p.Get("functionResponse"); fr.Exists() {
|
||||
fno := `{"type":"function_call_output"}`
|
||||
fno := []byte(`{"type":"function_call_output"}`)
|
||||
// Prefer a string result if present; otherwise embed the raw response as a string
|
||||
if res := fr.Get("response.result"); res.Exists() {
|
||||
fno, _ = sjson.Set(fno, "output", res.String())
|
||||
fno, _ = sjson.SetBytes(fno, "output", res.String())
|
||||
} else if resp := fr.Get("response"); resp.Exists() {
|
||||
fno, _ = sjson.Set(fno, "output", resp.Raw)
|
||||
fno, _ = sjson.SetBytes(fno, "output", resp.Raw)
|
||||
}
|
||||
// fno, _ = sjson.Set(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
|
||||
// fno, _ = sjson.SetBytes(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
|
||||
// attach the oldest queued call_id to pair the response
|
||||
// with its call. If the queue is empty, generate a new id.
|
||||
var id string
|
||||
@@ -182,8 +182,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
id = genCallID()
|
||||
}
|
||||
fno, _ = sjson.Set(fno, "call_id", id)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", fno)
|
||||
fno, _ = sjson.SetBytes(fno, "call_id", id)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", fno)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -193,8 +193,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Tools mapping: Gemini functionDeclarations -> Codex tools
|
||||
tools := root.Get("tools")
|
||||
if tools.IsArray() {
|
||||
out, _ = sjson.SetRaw(out, "tools", `[]`)
|
||||
out, _ = sjson.Set(out, "tool_choice", "auto")
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
|
||||
out, _ = sjson.SetBytes(out, "tool_choice", "auto")
|
||||
tarr := tools.Array()
|
||||
for i := 0; i < len(tarr); i++ {
|
||||
td := tarr[i]
|
||||
@@ -205,8 +205,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
farr := fns.Array()
|
||||
for j := 0; j < len(farr); j++ {
|
||||
fn := farr[j]
|
||||
tool := `{}`
|
||||
tool, _ = sjson.Set(tool, "type", "function")
|
||||
tool := []byte(`{}`)
|
||||
tool, _ = sjson.SetBytes(tool, "type", "function")
|
||||
if v := fn.Get("name"); v.Exists() {
|
||||
name := v.String()
|
||||
if short, ok := shortMap[name]; ok {
|
||||
@@ -214,32 +214,32 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
tool, _ = sjson.Set(tool, "name", name)
|
||||
tool, _ = sjson.SetBytes(tool, "name", name)
|
||||
}
|
||||
if v := fn.Get("description"); v.Exists() {
|
||||
tool, _ = sjson.Set(tool, "description", v.String())
|
||||
tool, _ = sjson.SetBytes(tool, "description", v.String())
|
||||
}
|
||||
if prm := fn.Get("parameters"); prm.Exists() {
|
||||
// Remove optional $schema field if present
|
||||
cleaned := prm.Raw
|
||||
cleaned, _ = sjson.Delete(cleaned, "$schema")
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
|
||||
cleaned := []byte(prm.Raw)
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
|
||||
} else if prm = fn.Get("parametersJsonSchema"); prm.Exists() {
|
||||
// Remove optional $schema field if present
|
||||
cleaned := prm.Raw
|
||||
cleaned, _ = sjson.Delete(cleaned, "$schema")
|
||||
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
|
||||
cleaned := []byte(prm.Raw)
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
|
||||
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
|
||||
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
|
||||
}
|
||||
tool, _ = sjson.Set(tool, "strict", false)
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", tool)
|
||||
tool, _ = sjson.SetBytes(tool, "strict", false)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", tool)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fixed flags aligning with Codex expectations
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
|
||||
|
||||
// Convert Gemini thinkingConfig to Codex reasoning.effort.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
} else {
|
||||
@@ -263,7 +263,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
}
|
||||
@@ -272,22 +272,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
if !effortSet {
|
||||
// No thinking config, set default effort
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
|
||||
}
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "stream", true)
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
|
||||
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.SetBytes(out, "stream", true)
|
||||
out, _ = sjson.SetBytes(out, "store", false)
|
||||
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
|
||||
|
||||
var pathsToLower []string
|
||||
toolsResult := gjson.Get(out, "tools")
|
||||
toolsResult := gjson.GetBytes(out, "tools")
|
||||
util.Walk(toolsResult, "", "type", &pathsToLower)
|
||||
for _, p := range pathsToLower {
|
||||
fullPath := fmt.Sprintf("tools.%s", p)
|
||||
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
|
||||
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies the simple shortening rule for a single name.
|
||||
|
||||
@@ -7,9 +7,9 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -23,7 +23,7 @@ type ConvertCodexResponseToGeminiParams struct {
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ResponseID string
|
||||
LastStorageOutput string
|
||||
LastStorageOutput []byte
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
|
||||
@@ -38,19 +38,19 @@ type ConvertCodexResponseToGeminiParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
|
||||
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of Gemini-compatible JSON responses
|
||||
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCodexResponseToGeminiParams{
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
LastStorageOutput: "",
|
||||
LastStorageOutput: nil,
|
||||
}
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
@@ -59,17 +59,17 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
typeStr := typeResult.String()
|
||||
|
||||
// Base Gemini response template
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
|
||||
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" && typeStr == "response.output_item.done" {
|
||||
template = (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
|
||||
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" {
|
||||
template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
|
||||
createdAtResult := rootResult.Get("response.created_at")
|
||||
if createdAtResult.Exists() {
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int()
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
||||
}
|
||||
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
|
||||
}
|
||||
|
||||
// Handle function call completion
|
||||
@@ -78,7 +78,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
// Create function call part
|
||||
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
n := itemResult.Get("name").String()
|
||||
@@ -86,7 +86,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
|
||||
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
|
||||
}
|
||||
|
||||
// Parse and set arguments
|
||||
@@ -94,47 +94,48 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
if argsStr != "" {
|
||||
argsResult := gjson.Parse(argsStr)
|
||||
if argsResult.IsObject() {
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
|
||||
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
|
||||
}
|
||||
}
|
||||
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = template
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
|
||||
|
||||
// Use this return to storage message
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
}
|
||||
|
||||
if typeStr == "response.created" { // Handle response creation - set model and response ID
|
||||
template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String())
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
|
||||
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
|
||||
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String()
|
||||
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
|
||||
part := `{"thought":true,"text":""}`
|
||||
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"thought":true,"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
} else if typeStr == "response.completed" { // Handle response completion with usage metadata
|
||||
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
|
||||
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
|
||||
totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
} else {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" {
|
||||
return []string{(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput, template}
|
||||
} else {
|
||||
return []string{template}
|
||||
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 {
|
||||
return [][]byte{
|
||||
append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...),
|
||||
template,
|
||||
}
|
||||
}
|
||||
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToGeminiNonStream converts a non-streaming Codex response to a non-streaming Gemini response.
|
||||
@@ -149,32 +150,32 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
// Base Gemini response template for non-streaming
|
||||
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
|
||||
|
||||
// Set model version
|
||||
template, _ = sjson.Set(template, "modelVersion", modelName)
|
||||
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
|
||||
|
||||
// Set response metadata from the completed response
|
||||
responseData := rootResult.Get("response")
|
||||
if responseData.Exists() {
|
||||
// Set response ID
|
||||
if responseId := responseData.Get("id"); responseId.Exists() {
|
||||
template, _ = sjson.Set(template, "responseId", responseId.String())
|
||||
template, _ = sjson.SetBytes(template, "responseId", responseId.String())
|
||||
}
|
||||
|
||||
// Set creation time
|
||||
if createdAt := responseData.Get("created_at"); createdAt.Exists() {
|
||||
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
|
||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
// Set usage metadata
|
||||
@@ -183,14 +184,14 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
totalTokens := inputTokens + outputTokens
|
||||
|
||||
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||
}
|
||||
|
||||
// Process output content to build parts array
|
||||
hasToolCall := false
|
||||
var pendingFunctionCalls []string
|
||||
var pendingFunctionCalls [][]byte
|
||||
|
||||
flushPendingFunctionCalls := func() {
|
||||
if len(pendingFunctionCalls) == 0 {
|
||||
@@ -199,7 +200,7 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
// Add all pending function calls as individual parts
|
||||
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
|
||||
for _, fc := range pendingFunctionCalls {
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", fc)
|
||||
}
|
||||
pendingFunctionCalls = nil
|
||||
}
|
||||
@@ -215,9 +216,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Add thinking content
|
||||
if content := value.Get("content"); content.Exists() {
|
||||
part := `{"text":"","thought":true}`
|
||||
part, _ = sjson.Set(part, "text", content.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"text":"","thought":true}`)
|
||||
part, _ = sjson.SetBytes(part, "text", content.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
}
|
||||
|
||||
case "message":
|
||||
@@ -229,9 +230,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
content.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
if contentItem.Get("type").String() == "output_text" {
|
||||
if text := contentItem.Get("text"); text.Exists() {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", text.String())
|
||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -241,21 +242,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "function_call":
|
||||
// Collect function call for potential merging with consecutive ones
|
||||
hasToolCall = true
|
||||
functionCall := `{"functionCall":{"args":{},"name":""}}`
|
||||
functionCall := []byte(`{"functionCall":{"args":{},"name":""}}`)
|
||||
{
|
||||
n := value.Get("name").String()
|
||||
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
|
||||
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
|
||||
}
|
||||
|
||||
// Parse and set arguments
|
||||
if argsStr := value.Get("arguments").String(); argsStr != "" {
|
||||
argsResult := gjson.Parse(argsStr)
|
||||
if argsResult.IsObject() {
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
|
||||
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,9 +271,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set finish reason based on whether there were tool calls
|
||||
if hasToolCall {
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||
}
|
||||
}
|
||||
return template
|
||||
@@ -307,6 +308,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
|
||||
return rev
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
@@ -29,42 +29,42 @@ import (
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
// Start with empty JSON object
|
||||
out := `{"instructions":""}`
|
||||
out := []byte(`{"instructions":""}`)
|
||||
|
||||
// Stream must be set to true
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
out, _ = sjson.SetBytes(out, "stream", stream)
|
||||
|
||||
// Codex not support temperature, top_p, top_k, max_output_tokens, so comment them
|
||||
// if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "temperature", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "temperature", v.Value())
|
||||
// }
|
||||
// if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "top_p", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "top_p", v.Value())
|
||||
// }
|
||||
// if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "top_k", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "top_k", v.Value())
|
||||
// }
|
||||
|
||||
// Map token limits
|
||||
// if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
|
||||
// }
|
||||
// if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
|
||||
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
|
||||
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
|
||||
// }
|
||||
|
||||
// Map reasoning effort
|
||||
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", v.Value())
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
|
||||
}
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
|
||||
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
|
||||
|
||||
// Model
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// Build tool name shortening map from original tools (if any)
|
||||
originalToolNameMap := map[string]string{}
|
||||
@@ -100,9 +100,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// if m.Get("role").String() == "system" {
|
||||
// c := m.Get("content")
|
||||
// if c.Type == gjson.String {
|
||||
// out, _ = sjson.Set(out, "instructions", c.String())
|
||||
// out, _ = sjson.SetBytes(out, "instructions", c.String())
|
||||
// } else if c.IsObject() && c.Get("type").String() == "text" {
|
||||
// out, _ = sjson.Set(out, "instructions", c.Get("text").String())
|
||||
// out, _ = sjson.SetBytes(out, "instructions", c.Get("text").String())
|
||||
// }
|
||||
// break
|
||||
// }
|
||||
@@ -110,7 +110,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// }
|
||||
|
||||
// Build input from messages, handling all message types including tool calls
|
||||
out, _ = sjson.SetRaw(out, "input", `[]`)
|
||||
out, _ = sjson.SetRawBytes(out, "input", []byte(`[]`))
|
||||
if messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
@@ -124,23 +124,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
content := m.Get("content").String()
|
||||
|
||||
// Create function_call_output object
|
||||
funcOutput := `{}`
|
||||
funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output")
|
||||
funcOutput, _ = sjson.Set(funcOutput, "call_id", toolCallID)
|
||||
funcOutput, _ = sjson.Set(funcOutput, "output", content)
|
||||
out, _ = sjson.SetRaw(out, "input.-1", funcOutput)
|
||||
funcOutput := []byte(`{}`)
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
|
||||
|
||||
default:
|
||||
// Handle regular messages
|
||||
msg := `{}`
|
||||
msg, _ = sjson.Set(msg, "type", "message")
|
||||
msg := []byte(`{}`)
|
||||
msg, _ = sjson.SetBytes(msg, "type", "message")
|
||||
if role == "system" {
|
||||
msg, _ = sjson.Set(msg, "role", "developer")
|
||||
msg, _ = sjson.SetBytes(msg, "role", "developer")
|
||||
} else {
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
msg, _ = sjson.SetBytes(msg, "role", role)
|
||||
}
|
||||
|
||||
msg, _ = sjson.SetRaw(msg, "content", `[]`)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content", []byte(`[]`))
|
||||
|
||||
// Handle regular content
|
||||
c := m.Get("content")
|
||||
@@ -150,10 +150,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if role == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", partType)
|
||||
part, _ = sjson.Set(part, "text", c.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", partType)
|
||||
part, _ = sjson.SetBytes(part, "text", c.String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
} else if c.Exists() && c.IsArray() {
|
||||
items := c.Array()
|
||||
for j := 0; j < len(items); j++ {
|
||||
@@ -165,39 +165,44 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if role == "assistant" {
|
||||
partType = "output_text"
|
||||
}
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", partType)
|
||||
part, _ = sjson.Set(part, "text", it.Get("text").String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", partType)
|
||||
part, _ = sjson.SetBytes(part, "text", it.Get("text").String())
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
case "image_url":
|
||||
// Map image inputs to input_image for Responses API
|
||||
if role == "user" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_image")
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_image")
|
||||
if u := it.Get("image_url.url"); u.Exists() {
|
||||
part, _ = sjson.Set(part, "image_url", u.String())
|
||||
part, _ = sjson.SetBytes(part, "image_url", u.String())
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
}
|
||||
case "file":
|
||||
if role == "user" {
|
||||
fileData := it.Get("file.file_data").String()
|
||||
filename := it.Get("file.filename").String()
|
||||
if fileData != "" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_file")
|
||||
part, _ = sjson.Set(part, "file_data", fileData)
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_file")
|
||||
part, _ = sjson.SetBytes(part, "file_data", fileData)
|
||||
if filename != "" {
|
||||
part, _ = sjson.Set(part, "filename", filename)
|
||||
part, _ = sjson.SetBytes(part, "filename", filename)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "input.-1", msg)
|
||||
// Don't emit empty assistant messages when only tool_calls
|
||||
// are present — Responses API needs function_call items
|
||||
// directly, otherwise call_id matching fails (#2132).
|
||||
if role != "assistant" || len(gjson.GetBytes(msg, "content").Array()) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
|
||||
}
|
||||
|
||||
// Handle tool calls for assistant messages as separate top-level objects
|
||||
if role == "assistant" {
|
||||
@@ -208,9 +213,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
tc := toolCallsArr[j]
|
||||
if tc.Get("type").String() == "function" {
|
||||
// Create function_call as top-level object
|
||||
funcCall := `{}`
|
||||
funcCall, _ = sjson.Set(funcCall, "type", "function_call")
|
||||
funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String())
|
||||
funcCall := []byte(`{}`)
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "type", "function_call")
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "call_id", tc.Get("id").String())
|
||||
{
|
||||
name := tc.Get("function.name").String()
|
||||
if short, ok := originalToolNameMap[name]; ok {
|
||||
@@ -218,10 +223,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
funcCall, _ = sjson.Set(funcCall, "name", name)
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "name", name)
|
||||
}
|
||||
funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String())
|
||||
out, _ = sjson.SetRaw(out, "input.-1", funcCall)
|
||||
funcCall, _ = sjson.SetBytes(funcCall, "arguments", tc.Get("function.arguments").String())
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", funcCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -235,26 +240,26 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
text := gjson.GetBytes(rawJSON, "text")
|
||||
if rf.Exists() {
|
||||
// Always create text object when response_format provided
|
||||
if !gjson.Get(out, "text").Exists() {
|
||||
out, _ = sjson.SetRaw(out, "text", `{}`)
|
||||
if !gjson.GetBytes(out, "text").Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
|
||||
}
|
||||
|
||||
rft := rf.Get("type").String()
|
||||
switch rft {
|
||||
case "text":
|
||||
out, _ = sjson.Set(out, "text.format.type", "text")
|
||||
out, _ = sjson.SetBytes(out, "text.format.type", "text")
|
||||
case "json_schema":
|
||||
js := rf.Get("json_schema")
|
||||
if js.Exists() {
|
||||
out, _ = sjson.Set(out, "text.format.type", "json_schema")
|
||||
out, _ = sjson.SetBytes(out, "text.format.type", "json_schema")
|
||||
if v := js.Get("name"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text.format.name", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.format.name", v.Value())
|
||||
}
|
||||
if v := js.Get("strict"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text.format.strict", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.format.strict", v.Value())
|
||||
}
|
||||
if v := js.Get("schema"); v.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "text.format.schema", []byte(v.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,23 +267,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// Map verbosity if provided
|
||||
if text.Exists() {
|
||||
if v := text.Get("verbosity"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "text.verbosity", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
|
||||
}
|
||||
}
|
||||
} else if text.Exists() {
|
||||
// If only text.verbosity present (no response_format), map verbosity
|
||||
if v := text.Get("verbosity"); v.Exists() {
|
||||
if !gjson.Get(out, "text").Exists() {
|
||||
out, _ = sjson.SetRaw(out, "text", `{}`)
|
||||
if !gjson.GetBytes(out, "text").Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
|
||||
}
|
||||
out, _ = sjson.Set(out, "text.verbosity", v.Value())
|
||||
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
|
||||
}
|
||||
}
|
||||
|
||||
// Map tools (flatten function fields)
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||
out, _ = sjson.SetRaw(out, "tools", `[]`)
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
|
||||
arr := tools.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
t := arr[i]
|
||||
@@ -286,13 +291,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
|
||||
// Only "function" needs structural conversion because Chat Completions nests details under "function".
|
||||
if toolType != "" && toolType != "function" && t.IsObject() {
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", []byte(t.Raw))
|
||||
continue
|
||||
}
|
||||
|
||||
if toolType == "function" {
|
||||
item := `{}`
|
||||
item, _ = sjson.Set(item, "type", "function")
|
||||
item := []byte(`{}`)
|
||||
item, _ = sjson.SetBytes(item, "type", "function")
|
||||
fn := t.Get("function")
|
||||
if fn.Exists() {
|
||||
if v := fn.Get("name"); v.Exists() {
|
||||
@@ -302,19 +307,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
item, _ = sjson.Set(item, "name", name)
|
||||
item, _ = sjson.SetBytes(item, "name", name)
|
||||
}
|
||||
if v := fn.Get("description"); v.Exists() {
|
||||
item, _ = sjson.Set(item, "description", v.Value())
|
||||
item, _ = sjson.SetBytes(item, "description", v.Value())
|
||||
}
|
||||
if v := fn.Get("parameters"); v.Exists() {
|
||||
item, _ = sjson.SetRaw(item, "parameters", v.Raw)
|
||||
item, _ = sjson.SetRawBytes(item, "parameters", []byte(v.Raw))
|
||||
}
|
||||
if v := fn.Get("strict"); v.Exists() {
|
||||
item, _ = sjson.Set(item, "strict", v.Value())
|
||||
item, _ = sjson.SetBytes(item, "strict", v.Value())
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", item)
|
||||
out, _ = sjson.SetRawBytes(out, "tools.-1", item)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -325,7 +330,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
|
||||
switch {
|
||||
case tc.Type == gjson.String:
|
||||
out, _ = sjson.Set(out, "tool_choice", tc.String())
|
||||
out, _ = sjson.SetBytes(out, "tool_choice", tc.String())
|
||||
case tc.IsObject():
|
||||
tcType := tc.Get("type").String()
|
||||
if tcType == "function" {
|
||||
@@ -337,21 +342,21 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
}
|
||||
choice := `{}`
|
||||
choice, _ = sjson.Set(choice, "type", "function")
|
||||
choice := []byte(`{}`)
|
||||
choice, _ = sjson.SetBytes(choice, "type", "function")
|
||||
if name != "" {
|
||||
choice, _ = sjson.Set(choice, "name", name)
|
||||
choice, _ = sjson.SetBytes(choice, "name", name)
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", choice)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", choice)
|
||||
} else if tcType != "" {
|
||||
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(tc.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
return []byte(out)
|
||||
out, _ = sjson.SetBytes(out, "store", false)
|
||||
return out
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies the simple shortening rule for a single name.
|
||||
|
||||
@@ -0,0 +1,635 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
|
||||
// Expects developer msg + user msg + function_call + function_call_output.
|
||||
// No empty assistant message should appear between user and function_call.
|
||||
func TestToolCallSimple(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": "sunny, 22C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
// system -> developer
|
||||
if items[0].Get("type").String() != "message" {
|
||||
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
|
||||
}
|
||||
if items[0].Get("role").String() != "developer" {
|
||||
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// user
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "user" {
|
||||
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
|
||||
// function_call, not an empty assistant msg
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("name").String() != "get_weather" {
|
||||
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
|
||||
}
|
||||
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
|
||||
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
|
||||
}
|
||||
|
||||
// function_call_output
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_1" {
|
||||
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
if items[3].Get("output").String() != "sunny, 22C" {
|
||||
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Assistant has both text content and tool_calls — the message should
|
||||
// be emitted (non-empty content), followed by function_call items.
|
||||
func TestToolCallWithContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me check the weather for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "rainy, 15C"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + assistant(with content) + function_call + function_call_output
|
||||
if len(items) != 4 {
|
||||
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
// assistant with content — should be kept
|
||||
if items[1].Get("type").String() != "message" {
|
||||
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
|
||||
}
|
||||
contentParts := items[1].Get("content").Array()
|
||||
if len(contentParts) == 0 {
|
||||
t.Errorf("item 1: assistant message should have content parts")
|
||||
}
|
||||
|
||||
if items[2].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
|
||||
}
|
||||
if items[2].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
|
||||
}
|
||||
|
||||
if items[3].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
|
||||
}
|
||||
if items[3].Get("call_id").String() != "call_abc" {
|
||||
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
|
||||
// and outputs must be translated and paired correctly.
|
||||
func TestMultipleToolCalls(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_paris",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Paris\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_london",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"London\"}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "call_tokyo",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"Tokyo\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
|
||||
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
|
||||
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user + 3 function_call + 3 function_call_output = 7
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
if items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
|
||||
}
|
||||
|
||||
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
|
||||
for i, expectedID := range expectedCallIDs {
|
||||
idx := i + 1
|
||||
if items[idx].Get("type").String() != "function_call" {
|
||||
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedID {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
|
||||
for i, expectedOutput := range expectedOutputs {
|
||||
idx := i + 4
|
||||
if items[idx].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
|
||||
}
|
||||
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
|
||||
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
|
||||
}
|
||||
if items[idx].Get("output").String() != expectedOutput {
|
||||
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for #2132: tool-call-only assistant messages (content:null)
|
||||
// must not produce an empty message item in the translated output.
|
||||
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Call a tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_x",
|
||||
"type": "function",
|
||||
"function": {"name": "do_thing", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "do_thing",
|
||||
"description": "Do a thing",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
typ := item.Get("type").String()
|
||||
role := item.Get("role").String()
|
||||
if typ == "message" && role == "assistant" {
|
||||
contentArr := item.Get("content").Array()
|
||||
if len(contentArr) == 0 {
|
||||
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should be exactly: user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
|
||||
t.Errorf("item 0: expected user message")
|
||||
}
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Two rounds of tool calling in one conversation, with a text reply in between.
|
||||
func TestMultiTurnToolCalling(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
|
||||
{"role": "assistant", "content": "It is sunny in Paris."},
|
||||
{"role": "user", "content": "And London?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
|
||||
if len(items) != 7 {
|
||||
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: unexpected empty assistant message", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// round 1
|
||||
if items[1].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
|
||||
}
|
||||
if items[1].Get("call_id").String() != "call_r1" {
|
||||
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
|
||||
}
|
||||
if items[2].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
|
||||
}
|
||||
|
||||
// text reply between rounds
|
||||
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
|
||||
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
|
||||
}
|
||||
|
||||
// round 2
|
||||
if items[5].Get("type").String() != "function_call" {
|
||||
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
|
||||
}
|
||||
if items[5].Get("call_id").String() != "call_r2" {
|
||||
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
|
||||
}
|
||||
if items[6].Get("type").String() != "function_call_output" {
|
||||
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Tool names over 64 chars get shortened, call_id stays the same.
|
||||
func TestToolNameShortening(t *testing.T) {
|
||||
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
|
||||
if len(longName) <= 64 {
|
||||
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
|
||||
}
|
||||
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do it"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_long",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"arguments": "{}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "` + longName + `",
|
||||
"description": "A tool with a very long name",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// find function_call
|
||||
var funcCallItem gjson.Result
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
funcCallItem = item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !funcCallItem.Exists() {
|
||||
t.Fatal("no function_call item found in output")
|
||||
}
|
||||
|
||||
// call_id unchanged
|
||||
if funcCallItem.Get("call_id").String() != "call_long" {
|
||||
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
|
||||
}
|
||||
|
||||
// name must be truncated
|
||||
translatedName := funcCallItem.Get("name").String()
|
||||
if translatedName == longName {
|
||||
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
|
||||
}
|
||||
if len(translatedName) > 64 {
|
||||
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
|
||||
}
|
||||
}
|
||||
|
||||
// content:"" (empty string, not null) should be treated the same as null.
|
||||
func TestEmptyStringContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Do something"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_empty",
|
||||
"type": "function",
|
||||
"function": {"name": "action", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "action",
|
||||
"description": "An action",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
|
||||
if len(item.Get("content").Array()) == 0 {
|
||||
t.Errorf("item %d: empty assistant message from content:\"\"", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// user + function_call + function_call_output
|
||||
if len(items) != 3 {
|
||||
t.Errorf("expected 3 input items, got %d", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
// Every function_call_output must have a matching function_call by call_id.
|
||||
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Multi-tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
|
||||
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
|
||||
]
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
|
||||
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
|
||||
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
items := gjson.Get(result, "input").Array()
|
||||
|
||||
// collect call_ids from function_call items
|
||||
callIDs := make(map[string]bool)
|
||||
for _, item := range items {
|
||||
if item.Get("type").String() == "function_call" {
|
||||
callIDs[item.Get("call_id").String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if item.Get("type").String() == "function_call_output" {
|
||||
outID := item.Get("call_id").String()
|
||||
if !callIDs[outID] {
|
||||
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2 calls, 2 outputs
|
||||
funcCallCount := 0
|
||||
funcOutputCount := 0
|
||||
for _, item := range items {
|
||||
switch item.Get("type").String() {
|
||||
case "function_call":
|
||||
funcCallCount++
|
||||
case "function_call_output":
|
||||
funcOutputCount++
|
||||
}
|
||||
}
|
||||
if funcCallCount != 2 {
|
||||
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
|
||||
}
|
||||
if funcOutputCount != 2 {
|
||||
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Tools array should carry over to the Responses format output.
|
||||
func TestToolsDefinitionTranslated(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
tools := gjson.Get(result, "tools").Array()
|
||||
if len(tools) == 0 {
|
||||
t.Fatal("no tools found in output")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, tool := range tools {
|
||||
if tool.Get("name").String() == "search" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
|
||||
}
|
||||
}
|
||||
@@ -41,8 +41,8 @@ type ConvertCliToOpenAIParams struct {
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &ConvertCliToOpenAIParams{
|
||||
Model: modelName,
|
||||
@@ -55,12 +55,12 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -70,67 +70,67 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
(*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String()
|
||||
(*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int()
|
||||
(*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String()
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Extract and set the model version.
|
||||
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
|
||||
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelResult.String())
|
||||
} else if cachedModel != "" {
|
||||
template, _ = sjson.Set(template, "model", cachedModel)
|
||||
template, _ = sjson.SetBytes(template, "model", cachedModel)
|
||||
} else if modelName != "" {
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||
}
|
||||
|
||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
||||
|
||||
// Extract and set the response ID.
|
||||
template, _ = sjson.Set(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
|
||||
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() {
|
||||
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
}
|
||||
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
}
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
}
|
||||
|
||||
if dataType == "response.reasoning_summary_text.delta" {
|
||||
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", deltaResult.String())
|
||||
}
|
||||
} else if dataType == "response.reasoning_summary_text.done" {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", "\n\n")
|
||||
} else if dataType == "response.output_text.delta" {
|
||||
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", deltaResult.String())
|
||||
}
|
||||
} else if dataType == "response.completed" {
|
||||
finishReason := "stop"
|
||||
if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||
} else if dataType == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Increment index for this new function call item.
|
||||
@@ -138,9 +138,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -148,59 +148,59 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", "")
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||
|
||||
deltaValue := rootResult.Get("delta").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.done" {
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||
// Arguments were already streamed via delta events; nothing to emit.
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||
fullArgs := rootResult.Get("arguments").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||
// Tool call was already announced via output_item.added; skip emission.
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
@@ -208,17 +208,17 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAINonStream converts a non-streaming Codex response to a non-streaming OpenAI response.
|
||||
@@ -233,53 +233,53 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
unixTimestamp := time.Now().Unix()
|
||||
|
||||
responseResult := rootResult.Get("response")
|
||||
|
||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelResult := responseResult.Get("model"); modelResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() {
|
||||
template, _ = sjson.Set(template, "created", createdAtResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "created", createdAtResult.Int())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if idResult := responseResult.Get("id"); idResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", idResult.String())
|
||||
template, _ = sjson.SetBytes(template, "id", idResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := responseResult.Get("usage"); usageResult.Exists() {
|
||||
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||
}
|
||||
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
|
||||
}
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,7 +289,7 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
outputArray := outputResult.Array()
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []string
|
||||
var toolCalls [][]byte
|
||||
|
||||
for _, outputItem := range outputArray {
|
||||
outputType := outputItem.Get("type").String()
|
||||
@@ -319,10 +319,10 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
}
|
||||
case "function_call":
|
||||
// Handle function call content
|
||||
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
functionCallTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||
|
||||
if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", callIdResult.String())
|
||||
}
|
||||
|
||||
if nameResult := outputItem.Get("name"); nameResult.Exists() {
|
||||
@@ -331,11 +331,11 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", n)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", n)
|
||||
}
|
||||
|
||||
if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", argsResult.String())
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, functionCallTemplate)
|
||||
@@ -344,22 +344,22 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
|
||||
// Set content and reasoning content if found
|
||||
if contentText != "" {
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", contentText)
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.content", contentText)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
|
||||
}
|
||||
|
||||
if reasoningText != "" {
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.reasoning_content", reasoningText)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
|
||||
}
|
||||
|
||||
// Add tool calls if any
|
||||
if len(toolCalls) > 0 {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls", []byte(`[]`))
|
||||
for _, toolCall := range toolCalls {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls.-1", toolCall)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,8 +367,8 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if statusResult := responseResult.Get("status"); statusResult.Exists() {
|
||||
status := statusResult.String()
|
||||
if status == "completed" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *test
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
gotModel := gjson.Get(out[0], "model").String()
|
||||
gotModel := gjson.GetBytes(out[0], "model").String()
|
||||
if gotModel != modelName {
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
@@ -40,7 +40,7 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
gotModel := gjson.Get(out[0], "model").String()
|
||||
gotModel := gjson.GetBytes(out[0], "model").String()
|
||||
if gotModel != modelName {
|
||||
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
if inputResult.Type == gjson.String {
|
||||
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
|
||||
input, _ := sjson.SetBytes([]byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`), "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", input)
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
|
||||
@@ -3,7 +3,6 @@ package responses
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -11,23 +10,25 @@ import (
|
||||
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||
// to OpenAI Responses SSE events (response.*).
|
||||
|
||||
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
|
||||
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) [][]byte {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
out := fmt.Sprintf("data: %s", string(rawJSON))
|
||||
return []string{out}
|
||||
out := make([]byte, 0, len(rawJSON)+len("data: "))
|
||||
out = append(out, []byte("data: ")...)
|
||||
out = append(out, rawJSON...)
|
||||
return [][]byte{out}
|
||||
}
|
||||
return []string{string(rawJSON)}
|
||||
return [][]byte{rawJSON}
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
|
||||
// from a non-streaming OpenAI Chat Completions response.
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []byte {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
responseResult := rootResult.Get("response")
|
||||
return responseResult.Raw
|
||||
return []byte(responseResult.Raw)
|
||||
}
|
||||
|
||||
67
internal/translator/common/bytes.go
Normal file
67
internal/translator/common/bytes.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func WrapGeminiCLIResponse(response []byte) []byte {
|
||||
out, err := sjson.SetRawBytes([]byte(`{"response":{}}`), "response", response)
|
||||
if err != nil {
|
||||
return response
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GeminiTokenCountJSON(count int64) []byte {
|
||||
out := make([]byte, 0, 96)
|
||||
out = append(out, `{"totalTokens":`...)
|
||||
out = strconv.AppendInt(out, count, 10)
|
||||
out = append(out, `,"promptTokensDetails":[{"modality":"TEXT","tokenCount":`...)
|
||||
out = strconv.AppendInt(out, count, 10)
|
||||
out = append(out, `}]}`...)
|
||||
return out
|
||||
}
|
||||
|
||||
func ClaudeInputTokensJSON(count int64) []byte {
|
||||
out := make([]byte, 0, 32)
|
||||
out = append(out, `{"input_tokens":`...)
|
||||
out = strconv.AppendInt(out, count, 10)
|
||||
out = append(out, '}')
|
||||
return out
|
||||
}
|
||||
|
||||
func SSEEventData(event string, payload []byte) []byte {
|
||||
out := make([]byte, 0, len(event)+len(payload)+14)
|
||||
out = append(out, "event: "...)
|
||||
out = append(out, event...)
|
||||
out = append(out, '\n')
|
||||
out = append(out, "data: "...)
|
||||
out = append(out, payload...)
|
||||
return out
|
||||
}
|
||||
|
||||
func AppendSSEEventString(out []byte, event, payload string, trailingNewlines int) []byte {
|
||||
out = append(out, "event: "...)
|
||||
out = append(out, event...)
|
||||
out = append(out, '\n')
|
||||
out = append(out, "data: "...)
|
||||
out = append(out, payload...)
|
||||
for i := 0; i < trailingNewlines; i++ {
|
||||
out = append(out, '\n')
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AppendSSEEventBytes(out []byte, event string, payload []byte, trailingNewlines int) []byte {
|
||||
out = append(out, "event: "...)
|
||||
out = append(out, event...)
|
||||
out = append(out, '\n')
|
||||
out = append(out, "data: "...)
|
||||
out = append(out, payload...)
|
||||
for i := 0; i < trailingNewlines; i++ {
|
||||
out = append(out, '\n')
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -6,10 +6,10 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -36,33 +36,32 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
out := []byte(`{"model":"","request":{"contents":[]}}`)
|
||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||
|
||||
// system instruction
|
||||
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
|
||||
systemInstruction := `{"role":"user","parts":[]}`
|
||||
systemInstruction := []byte(`{"role":"user","parts":[]}`)
|
||||
hasSystemParts := false
|
||||
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
|
||||
if systemPromptResult.Get("type").String() == "text" {
|
||||
textResult := systemPromptResult.Get("text")
|
||||
if textResult.Type == gjson.String {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", textResult.String())
|
||||
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", textResult.String())
|
||||
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
|
||||
hasSystemParts = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if hasSystemParts {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
|
||||
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstruction)
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.-1.text", systemResult.String())
|
||||
}
|
||||
|
||||
// contents
|
||||
@@ -77,28 +76,28 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
role = "model"
|
||||
}
|
||||
|
||||
contentJSON := `{"role":"","parts":[]}`
|
||||
contentJSON, _ = sjson.Set(contentJSON, "role", role)
|
||||
contentJSON := []byte(`{"role":"","parts":[]}`)
|
||||
contentJSON, _ = sjson.SetBytes(contentJSON, "role", role)
|
||||
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
|
||||
switch contentResult.Get("type").String() {
|
||||
case "text":
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String())
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
case "tool_use":
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
|
||||
part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
part, _ = sjson.Set(part, "functionCall.name", functionName)
|
||||
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
part := []byte(`{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`)
|
||||
part, _ = sjson.SetBytes(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
part, _ = sjson.SetBytes(part, "functionCall.name", functionName)
|
||||
part, _ = sjson.SetRawBytes(part, "functionCall.args", []byte(functionArgs))
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
}
|
||||
|
||||
case "tool_result":
|
||||
@@ -112,10 +111,10 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
part, _ = sjson.Set(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
case "image":
|
||||
source := contentResult.Get("source")
|
||||
@@ -123,21 +122,21 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
mimeType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
if mimeType != "" && data != "" {
|
||||
part := `{"inlineData":{"mime_type":"","data":""}}`
|
||||
part, _ = sjson.Set(part, "inlineData.mime_type", mimeType)
|
||||
part, _ = sjson.Set(part, "inlineData.data", data)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
part := []byte(`{"inlineData":{"mime_type":"","data":""}}`)
|
||||
part, _ = sjson.SetBytes(part, "inlineData.mime_type", mimeType)
|
||||
part, _ = sjson.SetBytes(part, "inlineData.data", data)
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON)
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentsResult.String())
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", contentsResult.String())
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -149,26 +148,27 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
||||
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||
tool, _ = sjson.DeleteBytes(tool, "strict")
|
||||
tool, _ = sjson.DeleteBytes(tool, "input_examples")
|
||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
|
||||
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte(`[{"functionDeclarations":[]}]`))
|
||||
hasTools = true
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools.0.functionDeclarations.-1", tool)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !hasTools {
|
||||
out, _ = sjson.Delete(out, "request.tools")
|
||||
out, _ = sjson.DeleteBytes(out, "request.tools")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,15 +186,15 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
|
||||
switch toolChoiceType {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
case "tool":
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||
if toolChoiceName != "" {
|
||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -206,8 +206,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive", "auto":
|
||||
// For adaptive thinking:
|
||||
@@ -219,25 +219,23 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||
}
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
}
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
||||
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
|
||||
}
|
||||
|
||||
outBytes := []byte(out)
|
||||
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
|
||||
|
||||
return outBytes
|
||||
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -46,8 +48,8 @@ var toolUseIDCounter uint64
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
||||
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
|
||||
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &Params{
|
||||
HasFirstResponse: false,
|
||||
@@ -59,34 +61,33 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
// Only send message_stop if we have actually output content
|
||||
if (*param).(*Params).HasContent {
|
||||
return []string{
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
|
||||
}
|
||||
return [][]byte{translatorcommon.AppendSSEEventString(nil, "message_stop", `{"type":"message_stop"}`, 3)}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Track whether tools are being used in this response chunk
|
||||
usedTool := false
|
||||
output := ""
|
||||
output := make([]byte, 0, 1024)
|
||||
appendEvent := func(event, payload string) {
|
||||
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
|
||||
}
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
// This is only sent for the very first response chunk to establish the streaming session
|
||||
if !(*param).(*Params).HasFirstResponse {
|
||||
output = "event: message_start\n"
|
||||
|
||||
// Create the initial message structure with default values according to Claude Code API specification
|
||||
// This follows the Claude Code API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
messageStartTemplate := []byte(`{"type":"message_start","message":{"id":"msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-20241022","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`)
|
||||
|
||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
}
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
}
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
||||
appendEvent("message_start", string(messageStartTemplate))
|
||||
|
||||
(*param).(*Params).HasFirstResponse = true
|
||||
}
|
||||
@@ -109,9 +110,8 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
if partResult.Get("thought").Bool() {
|
||||
// Continue existing thinking block if already in thinking state
|
||||
if (*param).(*Params).ResponseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
(*param).(*Params).HasContent = true
|
||||
} else {
|
||||
// Transition from another state to thinking
|
||||
@@ -122,19 +122,14 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||
(*param).(*Params).ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new thinking content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex))
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
(*param).(*Params).ResponseType = 2 // Set state to thinking
|
||||
(*param).(*Params).HasContent = true
|
||||
}
|
||||
@@ -142,9 +137,8 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block if already in content state
|
||||
if (*param).(*Params).ResponseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
(*param).(*Params).HasContent = true
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
@@ -155,19 +149,14 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||
(*param).(*Params).ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex))
|
||||
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
|
||||
appendEvent("content_block_delta", string(data))
|
||||
(*param).(*Params).ResponseType = 1 // Set state to content
|
||||
(*param).(*Params).HasContent = true
|
||||
}
|
||||
@@ -181,9 +170,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if (*param).(*Params).ResponseType == 3 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||
(*param).(*Params).ResponseIndex++
|
||||
(*param).(*Params).ResponseType = 0
|
||||
}
|
||||
@@ -197,26 +184,21 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
|
||||
// Close any other existing content block
|
||||
if (*param).(*Params).ResponseType != 0 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||
(*param).(*Params).ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new tool use content block
|
||||
// This creates the structure for a function call in Claude Code format
|
||||
output = output + "event: content_block_start\n"
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex))
|
||||
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
|
||||
appendEvent("content_block_start", string(data))
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
|
||||
appendEvent("content_block_delta", string(data))
|
||||
}
|
||||
(*param).(*Params).ResponseType = 3
|
||||
(*param).(*Params).HasContent = true
|
||||
@@ -231,34 +213,28 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// Only send final events if we have actually output content
|
||||
if (*param).(*Params).HasContent {
|
||||
// Close the final content block
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
|
||||
// Send the final message delta with usage information and stop reason
|
||||
output = output + "event: message_delta\n"
|
||||
output = output + `data: `
|
||||
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||
|
||||
// Create the message delta template with appropriate stop reason
|
||||
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
template := []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
// Set tool_use stop reason if tools were used in this response
|
||||
if usedTool {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
} else if finish := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
}
|
||||
|
||||
// Include thinking tokens in output token count if present
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
||||
|
||||
output = output + template + "\n\n\n"
|
||||
appendEvent("message_delta", string(template))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
return [][]byte{output}
|
||||
}
|
||||
|
||||
// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.
|
||||
@@ -270,21 +246,21 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
||||
// - param: A pointer to a parameter object for the conversion.
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Claude-compatible JSON response.
|
||||
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Claude-compatible JSON response.
|
||||
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
_ = originalRequestRawJSON
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())
|
||||
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||
out, _ = sjson.SetBytes(out, "id", root.Get("response.responseId").String())
|
||||
out, _ = sjson.SetBytes(out, "model", root.Get("response.modelVersion").String())
|
||||
|
||||
inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
||||
outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
|
||||
|
||||
parts := root.Get("response.candidates.0.content.parts")
|
||||
textBuilder := strings.Builder{}
|
||||
@@ -296,9 +272,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
if textBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"text","text":""}`)
|
||||
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
@@ -306,9 +282,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
if thinkingBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
}
|
||||
|
||||
@@ -332,15 +308,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
toolIDCounter++
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
inputRaw = args.Raw
|
||||
}
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
|
||||
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
|
||||
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -364,15 +340,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
}
|
||||
}
|
||||
}
|
||||
out, _ = sjson.Set(out, "stop_reason", stopReason)
|
||||
out, _ = sjson.SetBytes(out, "stop_reason", stopReason)
|
||||
|
||||
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
|
||||
out, _ = sjson.Delete(out, "usage")
|
||||
out, _ = sjson.DeleteBytes(out, "usage")
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -33,23 +34,23 @@ import (
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
template := []byte(`{"project":"","request":{},"model":""}`)
|
||||
template, _ = sjson.SetRawBytes(template, "request", rawJSON)
|
||||
template, _ = sjson.SetBytes(template, "model", gjson.GetBytes(template, "request.model").String())
|
||||
template, _ = sjson.DeleteBytes(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||
templateStr, errFixCLIToolResponse := fixCLIToolResponse(string(template))
|
||||
if errFixCLIToolResponse != nil {
|
||||
return []byte{}
|
||||
}
|
||||
template = []byte(templateStr)
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
systemInstructionResult := gjson.GetBytes(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
template, _ = sjson.SetRawBytes(template, "request.systemInstruction", []byte(systemInstructionResult.Raw))
|
||||
template, _ = sjson.DeleteBytes(template, "request.system_instruction")
|
||||
}
|
||||
rawJSON = []byte(template)
|
||||
rawJSON = template
|
||||
|
||||
// Normalize roles in request.contents: default to valid values if missing/invalid
|
||||
contents := gjson.GetBytes(rawJSON, "request.contents")
|
||||
@@ -110,12 +111,41 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
|
||||
return true
|
||||
})
|
||||
|
||||
// Filter out contents with empty parts to avoid Gemini API error:
|
||||
// "required oneof field 'data' must have one initialized field"
|
||||
filteredContents := []byte(`[]`)
|
||||
hasFiltered := false
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(_, content gjson.Result) bool {
|
||||
parts := content.Get("parts")
|
||||
if !parts.IsArray() || len(parts.Array()) == 0 {
|
||||
hasFiltered = true
|
||||
return true
|
||||
}
|
||||
filteredContents, _ = sjson.SetRawBytes(filteredContents, "-1", []byte(content.Raw))
|
||||
return true
|
||||
})
|
||||
if hasFiltered {
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents", filteredContents)
|
||||
}
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
CallNames []string // ordered function call names for backfilling empty response names
|
||||
}
|
||||
|
||||
// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name,
|
||||
// falling back to fallbackName if the original is empty.
|
||||
func backfillFunctionResponseName(raw string, fallbackName string) string {
|
||||
name := gjson.Get(raw, "functionResponse.name").String()
|
||||
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
||||
rawBytes, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
|
||||
raw = string(rawBytes)
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
@@ -142,7 +172,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
contentsWrapper := `{"contents":[]}`
|
||||
contentsWrapper := []byte(`{"contents":[]}`)
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
@@ -165,31 +195,28 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
if len(responsePartsInThisContent) > 0 {
|
||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||
|
||||
// Check if any pending groups can be satisfied
|
||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||
group := pendingGroups[i]
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
// Check if pending groups can be satisfied (FIFO: oldest group first)
|
||||
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
|
||||
group := pendingGroups[0]
|
||||
pendingGroups = pendingGroups[1:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||
for ri, response := range groupResponses {
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
|
||||
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw))
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||
break
|
||||
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,25 +225,26 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
functionCallsCount := 0
|
||||
var callNames []string
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsCount++
|
||||
callNames = append(callNames, part.Get("functionCall.name").String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if functionCallsCount > 0 {
|
||||
if len(callNames) > 0 {
|
||||
// Add the model content
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse model content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ResponsesNeeded: functionCallsCount,
|
||||
ResponsesNeeded: len(callNames),
|
||||
CallNames: callNames,
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
@@ -225,7 +253,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
@@ -233,7 +261,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -245,24 +273,25 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||
for ri, response := range groupResponses {
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
|
||||
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
|
||||
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw))
|
||||
}
|
||||
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
|
||||
result := []byte(input)
|
||||
result, _ = sjson.SetRawBytes(result, "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
|
||||
|
||||
return result, nil
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -29,8 +29,8 @@ import (
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - []string: The transformed request data in Gemini API format
|
||||
func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
|
||||
// - [][]byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
}
|
||||
@@ -43,22 +43,22 @@ func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalReq
|
||||
chunk = []byte(responseResult.Raw)
|
||||
}
|
||||
} else {
|
||||
chunkTemplate := "[]"
|
||||
chunkTemplate := []byte(`[]`)
|
||||
responseResult := gjson.ParseBytes(chunk)
|
||||
if responseResult.IsArray() {
|
||||
responseResultItems := responseResult.Array()
|
||||
for i := 0; i < len(responseResultItems); i++ {
|
||||
responseResultItem := responseResultItems[i]
|
||||
if responseResultItem.Get("response").Exists() {
|
||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
||||
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
chunk = []byte(chunkTemplate)
|
||||
chunk = chunkTemplate
|
||||
}
|
||||
return []string{string(chunk)}
|
||||
return [][]byte{chunk}
|
||||
}
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// ConvertGeminiCliResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
|
||||
@@ -72,15 +72,15 @@ func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalReq
|
||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||
//
|
||||
// Returns:
|
||||
// - string: A Gemini-compatible JSON response containing the response data
|
||||
func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
// - []byte: A Gemini-compatible JSON response containing the response data
|
||||
func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
return responseResult.Raw
|
||||
return []byte(responseResult.Raw)
|
||||
}
|
||||
return string(rawJSON)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||
return translatorcommon.GeminiTokenCountJSON(count)
|
||||
}
|
||||
|
||||
@@ -299,43 +299,43 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
if t.Get("type").String() == "function" {
|
||||
fn := t.Get("function")
|
||||
if fn.Exists() && fn.IsObject() {
|
||||
fnRaw := fn.Raw
|
||||
fnRaw := []byte(fn.Raw)
|
||||
if fn.Get("parameters").Exists() {
|
||||
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
|
||||
renamed, errRename := util.RenameKey(fn.Raw, "parameters", "parametersJsonSchema")
|
||||
if errRename != nil {
|
||||
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
fnRaw = renamed
|
||||
fnRaw = []byte(renamed)
|
||||
}
|
||||
} else {
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
fnRaw, _ = sjson.DeleteBytes(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", fnRaw)
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
|
||||
@@ -41,8 +41,8 @@ var functionCallIDCounter uint64
|
||||
// - param: A pointer to a parameter object for maintaining state between calls
|
||||
//
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
if *param == nil {
|
||||
*param = &convertCliResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
@@ -51,15 +51,15 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
|
||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return []string{}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
@@ -68,14 +68,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if err == nil {
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
finishReason := ""
|
||||
@@ -93,21 +93,21 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
@@ -145,33 +145,33 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
hasFunctionCall = true
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
|
||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||
functionCallIndex = len(toolCallsResult.Array())
|
||||
} else {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||
functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`)
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data == "" {
|
||||
@@ -185,32 +185,32 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
|
||||
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
|
||||
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
return [][]byte{template}
|
||||
}
|
||||
|
||||
// ConvertCliResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
|
||||
@@ -225,11 +225,11 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
// - param: A pointer to a parameter object for the conversion
|
||||
//
|
||||
// Returns:
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertCliResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertCliResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
|
||||
}
|
||||
return ""
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user