mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-29 16:54:41 +00:00
Compare commits
156 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
851712a49e | ||
|
|
9e34323a40 | ||
|
|
70897247b2 | ||
|
|
9c341f5aa5 | ||
|
|
e3e741d0be | ||
|
|
7c7c5fd967 | ||
|
|
fe8c7a62aa | ||
|
|
2af4a8dc12 | ||
|
|
0f53b952b2 | ||
|
|
7b2ae7377a | ||
|
|
c2ab288c7d | ||
|
|
dbb433fcf8 | ||
|
|
2abf00b5a6 | ||
|
|
275839e5c9 | ||
|
|
f30ffd5f5e | ||
|
|
bc9a24d705 | ||
|
|
2c879f13ef | ||
|
|
07b4a08979 | ||
|
|
497339f055 | ||
|
|
7f612bb069 | ||
|
|
5743b78694 | ||
|
|
2e6a2b655c | ||
|
|
cb47ac21bf | ||
|
|
a1394b4596 | ||
|
|
9e97948f03 | ||
|
|
8f780e7280 | ||
|
|
46c6fb1e7a | ||
|
|
9f9fec5d4c | ||
|
|
e95be10485 | ||
|
|
f3d58fa0ce | ||
|
|
8c0eaa1f71 | ||
|
|
405df58f72 | ||
|
|
e7f13aa008 | ||
|
|
7cb6a9b89a | ||
|
|
9aa5344c29 | ||
|
|
8ba0ebbd2a | ||
|
|
c65407ab9f | ||
|
|
9e59685212 | ||
|
|
4a4dfaa910 | ||
|
|
0d6ecb0191 | ||
|
|
f16461bfe7 | ||
|
|
9fccc86b71 | ||
|
|
74683560a7 | ||
|
|
1e4f9dd438 | ||
|
|
b9ff916494 | ||
|
|
9bf4a0cad2 | ||
|
|
c32e2a8196 | ||
|
|
873d41582f | ||
|
|
6fb7d85558 | ||
|
|
d5e3e32d58 | ||
|
|
f353a54555 | ||
|
|
1d6e2e751d | ||
|
|
cc50b63422 | ||
|
|
15ae83a15b | ||
|
|
81b369aed9 | ||
|
|
ecc850bfb7 | ||
|
|
19b4ef33e0 | ||
|
|
7ca045d8b9 | ||
|
|
25b9df478c | ||
|
|
abfca6aab2 | ||
|
|
3c71c075db | ||
|
|
9c2992bfb2 | ||
|
|
269a1c5452 | ||
|
|
22ce65ac72 | ||
|
|
a2f8f59192 | ||
|
|
8c7c446f33 | ||
|
|
51611c25d7 | ||
|
|
eb1bbaa63b | ||
|
|
30a59168d7 | ||
|
|
4c8026ac3d | ||
|
|
8aeb4b7d54 | ||
|
|
b2172cb047 | ||
|
|
c8884f5e25 | ||
|
|
d9c6317c84 | ||
|
|
d29ec95526 | ||
|
|
ef4508dbc8 | ||
|
|
f775e46fe2 | ||
|
|
65ad5c0c9d | ||
|
|
88bf4e77ec | ||
|
|
194f66ca9c | ||
|
|
a4f8015caa | ||
|
|
ffd129909e | ||
|
|
9332316383 | ||
|
|
6dcbbf64c3 | ||
|
|
c9aa1ff99d | ||
|
|
2ce3553612 | ||
|
|
2e14f787d4 | ||
|
|
523b41ccd2 | ||
|
|
09970dc7af | ||
|
|
d81abd401c | ||
|
|
a6cba25bc1 | ||
|
|
c6fa1d0e67 | ||
|
|
ac56e1e88b | ||
|
|
a9ee971e1c | ||
|
|
73cef3a25a | ||
|
|
9b72ea9efa | ||
|
|
9f364441e8 | ||
|
|
e49a1c07bf | ||
|
|
5364a2471d | ||
|
|
fef4fdb0eb | ||
|
|
c2bf600a39 | ||
|
|
8d9f4edf9b | ||
|
|
020e61d0da | ||
|
|
6184c43319 | ||
|
|
2cbe4a790c | ||
|
|
68b3565d7b | ||
|
|
3f385a8572 | ||
|
|
9823dc35e1 | ||
|
|
059bfee91b | ||
|
|
7beaf0eaa2 | ||
|
|
1fef90ff58 | ||
|
|
8447fd27a0 | ||
|
|
7831cba9f6 | ||
|
|
e02b2d58d5 | ||
|
|
28726632a9 | ||
|
|
0f63d973be | ||
|
|
3b26129c82 | ||
|
|
d4bb4e6624 | ||
|
|
fa2abd560a | ||
|
|
0766c49f93 | ||
|
|
a7ffc77e3d | ||
|
|
e641fde25c | ||
|
|
564c2d763e | ||
|
|
5717c7f2f4 | ||
|
|
8734d4cb90 | ||
|
|
2f6004d74a | ||
|
|
08779cc8a8 | ||
|
|
5baa753539 | ||
|
|
92fb6b012a | ||
|
|
ead98e4bca | ||
|
|
a1634909e8 | ||
|
|
8f06f6a9ed | ||
|
|
ace7c0ccb4 | ||
|
|
f87fe0a0e8 | ||
|
|
87edc6f35e | ||
|
|
1d2fe55310 | ||
|
|
c175821cc4 | ||
|
|
239a28793c | ||
|
|
c421d653e7 | ||
|
|
2542c2920d | ||
|
|
52e46ced1b | ||
|
|
cf9daf470c | ||
|
|
c9301a6d18 | ||
|
|
0e77e93e5d | ||
|
|
5977af96a0 | ||
|
|
5bb9c2a2bd | ||
|
|
0b5bbe9234 | ||
|
|
14c74e5e84 | ||
|
|
6448d0ee7c | ||
|
|
b0c17af2cf | ||
|
|
aa8526edc0 | ||
|
|
ac3ca0ad8e | ||
|
|
08d21b76e2 | ||
|
|
33aa665555 | ||
|
|
00280b6fe8 | ||
|
|
52760a4eaa |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -50,3 +50,4 @@ _bmad-output/*
|
|||||||
# macOS
|
# macOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
._*
|
._*
|
||||||
|
*.bak
|
||||||
|
|||||||
76
README.md
76
README.md
@@ -13,6 +13,82 @@ The Plus release stays in lockstep with the mainline features.
|
|||||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||||
|
|
||||||
|
## New Features (Plus Enhanced)
|
||||||
|
|
||||||
|
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
|
||||||
|
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
|
||||||
|
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
|
||||||
|
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
|
||||||
|
- **Device Fingerprint**: Device fingerprint generation for enhanced security
|
||||||
|
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
|
||||||
|
- **Usage Checker**: Real-time usage monitoring and quota management
|
||||||
|
- **Model Converter**: Unified model name conversion across providers
|
||||||
|
- **UTF-8 Stream Processing**: Improved streaming response handling
|
||||||
|
|
||||||
|
## Kiro Authentication
|
||||||
|
|
||||||
|
### Web-based OAuth Login
|
||||||
|
|
||||||
|
Access the Kiro OAuth web interface at:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://your-server:8080/v0/oauth/kiro
|
||||||
|
```
|
||||||
|
|
||||||
|
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
||||||
|
- AWS Builder ID login
|
||||||
|
- AWS Identity Center (IDC) login
|
||||||
|
- Token import from Kiro IDE
|
||||||
|
|
||||||
|
## Quick Deployment with Docker
|
||||||
|
|
||||||
|
### One-Command Deployment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Create deployment directory
|
||||||
|
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||||
|
|
||||||
|
# Create docker-compose.yml
|
||||||
|
cat > docker-compose.yml << 'EOF'
|
||||||
|
services:
|
||||||
|
cli-proxy-api:
|
||||||
|
image: 17600006524/cli-proxy-api-plus:latest
|
||||||
|
container_name: cli-proxy-api-plus
|
||||||
|
ports:
|
||||||
|
- "8317:8317"
|
||||||
|
volumes:
|
||||||
|
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||||
|
- ./auths:/root/.cli-proxy-api
|
||||||
|
- ./logs:/CLIProxyAPI/logs
|
||||||
|
restart: unless-stopped
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Download example config
|
||||||
|
curl -o config.yaml https://raw.githubusercontent.com/linlang781/CLIProxyAPIPlus/main/config.example.yaml
|
||||||
|
|
||||||
|
# Pull and start
|
||||||
|
docker compose pull && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Edit `config.yaml` before starting:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Basic configuration example
|
||||||
|
server:
|
||||||
|
port: 8317
|
||||||
|
|
||||||
|
# Add your provider configurations here
|
||||||
|
```
|
||||||
|
|
||||||
|
### Update to Latest Version
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ~/cli-proxy
|
||||||
|
docker compose pull && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||||
|
|||||||
76
README_CN.md
76
README_CN.md
@@ -13,6 +13,82 @@
|
|||||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||||
|
|
||||||
|
## 新增功能 (Plus 增强版)
|
||||||
|
|
||||||
|
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
|
||||||
|
- **请求限流器**: 内置请求限流,防止 API 滥用
|
||||||
|
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
|
||||||
|
- **监控指标**: 请求指标收集,用于监控和调试
|
||||||
|
- **设备指纹**: 设备指纹生成,增强安全性
|
||||||
|
- **冷却管理**: 智能冷却机制,应对 API 速率限制
|
||||||
|
- **用量检查器**: 实时用量监控和配额管理
|
||||||
|
- **模型转换器**: 跨供应商的统一模型名称转换
|
||||||
|
- **UTF-8 流处理**: 改进的流式响应处理
|
||||||
|
|
||||||
|
## Kiro 认证
|
||||||
|
|
||||||
|
### 网页端 OAuth 登录
|
||||||
|
|
||||||
|
访问 Kiro OAuth 网页认证界面:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://your-server:8080/v0/oauth/kiro
|
||||||
|
```
|
||||||
|
|
||||||
|
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
||||||
|
- AWS Builder ID 登录
|
||||||
|
- AWS Identity Center (IDC) 登录
|
||||||
|
- 从 Kiro IDE 导入令牌
|
||||||
|
|
||||||
|
## Docker 快速部署
|
||||||
|
|
||||||
|
### 一键部署
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建部署目录
|
||||||
|
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||||
|
|
||||||
|
# 创建 docker-compose.yml
|
||||||
|
cat > docker-compose.yml << 'EOF'
|
||||||
|
services:
|
||||||
|
cli-proxy-api:
|
||||||
|
image: 17600006524/cli-proxy-api-plus:latest
|
||||||
|
container_name: cli-proxy-api-plus
|
||||||
|
ports:
|
||||||
|
- "8317:8317"
|
||||||
|
volumes:
|
||||||
|
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||||
|
- ./auths:/root/.cli-proxy-api
|
||||||
|
- ./logs:/CLIProxyAPI/logs
|
||||||
|
restart: unless-stopped
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# 下载示例配置
|
||||||
|
curl -o config.yaml https://raw.githubusercontent.com/linlang781/CLIProxyAPIPlus/main/config.example.yaml
|
||||||
|
|
||||||
|
# 拉取并启动
|
||||||
|
docker compose pull && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### 配置说明
|
||||||
|
|
||||||
|
启动前请编辑 `config.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# 基本配置示例
|
||||||
|
server:
|
||||||
|
port: 8317
|
||||||
|
|
||||||
|
# 在此添加你的供应商配置
|
||||||
|
```
|
||||||
|
|
||||||
|
### 更新到最新版本
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ~/cli-proxy
|
||||||
|
docker compose pull && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -533,6 +534,13 @@ func main() {
|
|||||||
}
|
}
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
|
||||||
|
// 初始化并启动 Kiro token 后台刷新
|
||||||
|
if cfg.AuthDir != "" {
|
||||||
|
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||||
|
defer kiro.StopGlobalRefreshManager()
|
||||||
|
}
|
||||||
|
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,6 +146,15 @@ codex-instructions-enabled: false
|
|||||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||||
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||||
|
# cloak: # optional: request cloaking for non-Claude-Code clients
|
||||||
|
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
|
||||||
|
# # "always": always apply cloaking
|
||||||
|
# # "never": never apply cloaking
|
||||||
|
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
|
||||||
|
# # true: strip all user system messages, keep only Claude Code prompt
|
||||||
|
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||||
|
# - "API"
|
||||||
|
# - "proxy"
|
||||||
|
|
||||||
# Kiro (AWS CodeWhisperer) configuration
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
# Note: Kiro API currently only operates in us-east-1 region
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -21,6 +21,7 @@ require (
|
|||||||
golang.org/x/crypto v0.45.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/net v0.47.0
|
golang.org/x/net v0.47.0
|
||||||
golang.org/x/oauth2 v0.30.0
|
golang.org/x/oauth2 v0.30.0
|
||||||
|
golang.org/x/sync v0.18.0
|
||||||
golang.org/x/term v0.37.0
|
golang.org/x/term v0.37.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
@@ -39,6 +40,7 @@ require (
|
|||||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/emirpasic/gods v1.18.1 // indirect
|
github.com/emirpasic/gods v1.18.1 // indirect
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||||
@@ -68,8 +70,8 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/sync v0.18.0 // indirect
|
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.1 // indirect
|
google.golang.org/protobuf v1.34.1 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -35,6 +35,8 @@ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc
|
|||||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
@@ -157,6 +159,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -70,7 +71,7 @@ type apiCallResponse struct {
|
|||||||
// - Authorization: Bearer <key>
|
// - Authorization: Bearer <key>
|
||||||
// - X-Management-Key: <key>
|
// - X-Management-Key: <key>
|
||||||
//
|
//
|
||||||
// Request JSON:
|
// Request JSON (supports both application/json and application/cbor):
|
||||||
// - auth_index / authIndex / AuthIndex (optional):
|
// - auth_index / authIndex / AuthIndex (optional):
|
||||||
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
||||||
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
||||||
@@ -90,10 +91,12 @@ type apiCallResponse struct {
|
|||||||
// 2. Global config proxy-url
|
// 2. Global config proxy-url
|
||||||
// 3. Direct connect (environment proxies are not used)
|
// 3. Direct connect (environment proxies are not used)
|
||||||
//
|
//
|
||||||
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
|
// Response (returned with HTTP 200 when the APICall itself succeeds):
|
||||||
// - status_code: Upstream HTTP status code.
|
//
|
||||||
// - header: Upstream response headers.
|
// Format matches request Content-Type (application/json or application/cbor)
|
||||||
// - body: Upstream response body as string.
|
// - status_code: Upstream HTTP status code.
|
||||||
|
// - header: Upstream response headers.
|
||||||
|
// - body: Upstream response body as string.
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
//
|
//
|
||||||
@@ -107,10 +110,28 @@ type apiCallResponse struct {
|
|||||||
// -H "Content-Type: application/json" \
|
// -H "Content-Type: application/json" \
|
||||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
||||||
func (h *Handler) APICall(c *gin.Context) {
|
func (h *Handler) APICall(c *gin.Context) {
|
||||||
|
// Detect content type
|
||||||
|
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
|
||||||
|
isCBOR := strings.Contains(contentType, "application/cbor")
|
||||||
|
|
||||||
var body apiCallRequest
|
var body apiCallRequest
|
||||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
// Parse request body based on content type
|
||||||
return
|
if isCBOR {
|
||||||
|
rawBody, errRead := io.ReadAll(c.Request.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
||||||
@@ -209,11 +230,23 @@ func (h *Handler) APICall(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, apiCallResponse{
|
response := apiCallResponse{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
Header: resp.Header,
|
Header: resp.Header,
|
||||||
Body: string(respBody),
|
Body: string(respBody),
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Return response in the same format as the request
|
||||||
|
if isCBOR {
|
||||||
|
cborData, errMarshal := cbor.Marshal(response)
|
||||||
|
if errMarshal != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/cbor", cborData)
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func firstNonEmptyString(values ...*string) string {
|
func firstNonEmptyString(values ...*string) string {
|
||||||
|
|||||||
149
internal/api/handlers/management/api_tools_cbor_test.go
Normal file
149
internal/api/handlers/management/api_tools_cbor_test.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAPICall_CBOR_Support(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Create a test handler
|
||||||
|
h := &Handler{}
|
||||||
|
|
||||||
|
// Create test request data
|
||||||
|
reqData := apiCallRequest{
|
||||||
|
Method: "GET",
|
||||||
|
URL: "https://httpbin.org/get",
|
||||||
|
Header: map[string]string{
|
||||||
|
"User-Agent": "test-client",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("JSON request and response", func(t *testing.T) {
|
||||||
|
// Marshal request as JSON
|
||||||
|
jsonData, err := json.Marshal(reqData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Create response recorder
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create Gin context
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
h.APICall(c)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
|
||||||
|
t.Logf("Response status: %d", w.Code)
|
||||||
|
t.Logf("Response body: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check content type
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
|
||||||
|
t.Errorf("Expected JSON response, got: %s", contentType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CBOR request and response", func(t *testing.T) {
|
||||||
|
// Marshal request as CBOR
|
||||||
|
cborData, err := cbor.Marshal(reqData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal CBOR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
|
||||||
|
req.Header.Set("Content-Type", "application/cbor")
|
||||||
|
|
||||||
|
// Create response recorder
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create Gin context
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
h.APICall(c)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
|
||||||
|
t.Logf("Response status: %d", w.Code)
|
||||||
|
t.Logf("Response body: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check content type
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
|
||||||
|
t.Errorf("Expected CBOR response, got: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to decode CBOR response
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
var response apiCallResponse
|
||||||
|
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Errorf("Failed to unmarshal CBOR response: %v", err)
|
||||||
|
} else {
|
||||||
|
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
|
||||||
|
// Test data
|
||||||
|
testReq := apiCallRequest{
|
||||||
|
Method: "POST",
|
||||||
|
URL: "https://example.com/api",
|
||||||
|
Header: map[string]string{
|
||||||
|
"Authorization": "Bearer $TOKEN$",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
Data: `{"key":"value"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to CBOR
|
||||||
|
cborData, err := cbor.Marshal(testReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal to CBOR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode from CBOR
|
||||||
|
var decoded apiCallRequest
|
||||||
|
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields
|
||||||
|
if decoded.Method != testReq.Method {
|
||||||
|
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
|
||||||
|
}
|
||||||
|
if decoded.URL != testReq.URL {
|
||||||
|
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
|
||||||
|
}
|
||||||
|
if decoded.Data != testReq.Data {
|
||||||
|
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
|
||||||
|
}
|
||||||
|
if len(decoded.Header) != len(testReq.Header) {
|
||||||
|
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -22,6 +23,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
@@ -235,14 +237,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
|||||||
log.Infof("callback forwarder on port %d stopped", port)
|
log.Infof("callback forwarder on port %d stopped", port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeAntigravityFileName(email string) string {
|
|
||||||
if strings.TrimSpace(email) == "" {
|
|
||||||
return "antigravity.json"
|
|
||||||
}
|
|
||||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
|
||||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||||
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
||||||
return "", fmt.Errorf("server port is not configured")
|
return "", fmt.Errorf("server port is not configured")
|
||||||
@@ -752,6 +746,72 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PatchAuthFileStatus toggles the disabled state of an auth file
|
||||||
|
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||||
|
if h.authManager == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Disabled *bool `json:"disabled"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(req.Name)
|
||||||
|
if name == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Disabled == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Find auth by name or ID
|
||||||
|
var targetAuth *coreauth.Auth
|
||||||
|
if auth, ok := h.authManager.GetByID(name); ok {
|
||||||
|
targetAuth = auth
|
||||||
|
} else {
|
||||||
|
auths := h.authManager.List()
|
||||||
|
for _, auth := range auths {
|
||||||
|
if auth.FileName == name {
|
||||||
|
targetAuth = auth
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetAuth == nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update disabled state
|
||||||
|
targetAuth.Disabled = *req.Disabled
|
||||||
|
if *req.Disabled {
|
||||||
|
targetAuth.Status = coreauth.StatusDisabled
|
||||||
|
targetAuth.StatusMessage = "disabled via management API"
|
||||||
|
} else {
|
||||||
|
targetAuth.Status = coreauth.StatusActive
|
||||||
|
targetAuth.StatusMessage = ""
|
||||||
|
}
|
||||||
|
targetAuth.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||||
if h == nil || h.authManager == nil {
|
if h == nil || h.authManager == nil {
|
||||||
return
|
return
|
||||||
@@ -918,67 +978,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
rawCode := resultMap["code"]
|
rawCode := resultMap["code"]
|
||||||
code := strings.Split(rawCode, "#")[0]
|
code := strings.Split(rawCode, "#")[0]
|
||||||
|
|
||||||
// Exchange code for tokens (replicate logic using updated redirect_uri)
|
// Exchange code for tokens using internal auth service
|
||||||
// Extract client_id from the modified auth URL
|
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
|
||||||
clientID := ""
|
if errExchange != nil {
|
||||||
if u2, errP := url.Parse(authURL); errP == nil {
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
|
||||||
clientID = u2.Query().Get("client_id")
|
|
||||||
}
|
|
||||||
// Build request
|
|
||||||
bodyMap := map[string]any{
|
|
||||||
"code": code,
|
|
||||||
"state": state,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"client_id": clientID,
|
|
||||||
"redirect_uri": "http://localhost:54545/callback",
|
|
||||||
"code_verifier": pkceCodes.CodeVerifier,
|
|
||||||
}
|
|
||||||
bodyJSON, _ := json.Marshal(bodyMap)
|
|
||||||
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("failed to close response body: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var tResp struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
Account struct {
|
|
||||||
EmailAddress string `json:"email_address"`
|
|
||||||
} `json:"account"`
|
|
||||||
}
|
|
||||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bundle := &claude.ClaudeAuthBundle{
|
|
||||||
TokenData: claude.ClaudeTokenData{
|
|
||||||
AccessToken: tResp.AccessToken,
|
|
||||||
RefreshToken: tResp.RefreshToken,
|
|
||||||
Email: tResp.Account.EmailAddress,
|
|
||||||
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create token storage
|
// Create token storage
|
||||||
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
||||||
@@ -1018,17 +1025,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
fmt.Println("Initializing Google authentication...")
|
fmt.Println("Initializing Google authentication...")
|
||||||
|
|
||||||
// OAuth2 configuration (mirrors internal/auth/gemini)
|
// OAuth2 configuration using exported constants from internal/auth/gemini
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
|
ClientID: geminiAuth.ClientID,
|
||||||
ClientSecret: "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",
|
ClientSecret: geminiAuth.ClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback",
|
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
|
||||||
Scopes: []string{
|
Scopes: geminiAuth.Scopes,
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
Endpoint: google.Endpoint,
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
},
|
|
||||||
Endpoint: google.Endpoint,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build authorization URL and return it immediately
|
// Build authorization URL and return it immediately
|
||||||
@@ -1150,13 +1153,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
ifToken["client_id"] = geminiAuth.ClientID
|
||||||
ifToken["client_secret"] = "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"
|
ifToken["client_secret"] = geminiAuth.ClientSecret
|
||||||
ifToken["scopes"] = []string{
|
ifToken["scopes"] = geminiAuth.Scopes
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
}
|
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := geminiAuth.GeminiTokenStorage{
|
ts := geminiAuth.GeminiTokenStorage{
|
||||||
@@ -1343,74 +1342,34 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Authorization code received, exchanging for tokens...")
|
log.Debug("Authorization code received, exchanging for tokens...")
|
||||||
// Extract client_id from authURL
|
// Exchange code for tokens using internal auth service
|
||||||
clientID := ""
|
bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
|
||||||
if u2, errP := url.Parse(authURL); errP == nil {
|
if errExchange != nil {
|
||||||
clientID = u2.Query().Get("client_id")
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
|
||||||
}
|
|
||||||
// Exchange code for tokens with redirect equal to mgmtRedirect
|
|
||||||
form := url.Values{
|
|
||||||
"grant_type": {"authorization_code"},
|
|
||||||
"client_id": {clientID},
|
|
||||||
"code": {code},
|
|
||||||
"redirect_uri": {"http://localhost:1455/auth/callback"},
|
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
|
||||||
}
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
|
||||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
// Extract additional info for filename generation
|
||||||
if resp.StatusCode != http.StatusOK {
|
claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
planType := ""
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
hashAccountID := ""
|
||||||
return
|
|
||||||
}
|
|
||||||
var tokenResp struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
}
|
|
||||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
|
||||||
email := ""
|
|
||||||
accountID := ""
|
|
||||||
if claims != nil {
|
if claims != nil {
|
||||||
email = claims.GetUserEmail()
|
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||||
accountID = claims.GetAccountID()
|
if accountID := claims.GetAccountID(); accountID != "" {
|
||||||
}
|
digest := sha256.Sum256([]byte(accountID))
|
||||||
// Build bundle compatible with existing storage
|
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||||
bundle := &codex.CodexAuthBundle{
|
}
|
||||||
TokenData: codex.CodexTokenData{
|
|
||||||
IDToken: tokenResp.IDToken,
|
|
||||||
AccessToken: tokenResp.AccessToken,
|
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
|
||||||
AccountID: accountID,
|
|
||||||
Email: email,
|
|
||||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token storage and persist
|
// Create token storage and persist
|
||||||
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
||||||
|
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
ID: fileName,
|
||||||
Provider: "codex",
|
Provider: "codex",
|
||||||
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": tokenStorage.Email,
|
"email": tokenStorage.Email,
|
||||||
@@ -1436,23 +1395,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
const (
|
|
||||||
antigravityCallbackPort = 51121
|
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
|
||||||
antigravityClientSecret = "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"
|
|
||||||
)
|
|
||||||
var antigravityScopes = []string{
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
"https://www.googleapis.com/auth/cclog",
|
|
||||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
|
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
|
||||||
|
|
||||||
state, errState := misc.GenerateRandomState()
|
state, errState := misc.GenerateRandomState()
|
||||||
if errState != nil {
|
if errState != nil {
|
||||||
log.Errorf("Failed to generate state parameter: %v", errState)
|
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||||
@@ -1460,17 +1408,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
|
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
|
||||||
|
authURL := authSvc.BuildAuthURL(state, redirectURI)
|
||||||
params := url.Values{}
|
|
||||||
params.Set("access_type", "offline")
|
|
||||||
params.Set("client_id", antigravityClientID)
|
|
||||||
params.Set("prompt", "consent")
|
|
||||||
params.Set("redirect_uri", redirectURI)
|
|
||||||
params.Set("response_type", "code")
|
|
||||||
params.Set("scope", strings.Join(antigravityScopes, " "))
|
|
||||||
params.Set("state", state)
|
|
||||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
|
||||||
|
|
||||||
RegisterOAuthSession(state, "antigravity")
|
RegisterOAuthSession(state, "antigravity")
|
||||||
|
|
||||||
@@ -1484,7 +1423,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var errStart error
|
var errStart error
|
||||||
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
|
||||||
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
return
|
return
|
||||||
@@ -1493,7 +1432,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if isWebUI {
|
if isWebUI {
|
||||||
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
|
||||||
}
|
}
|
||||||
|
|
||||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
||||||
@@ -1533,93 +1472,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
|
||||||
form := url.Values{}
|
if errToken != nil {
|
||||||
form.Set("code", authCode)
|
log.Errorf("Failed to exchange token: %v", errToken)
|
||||||
form.Set("client_id", antigravityClientID)
|
|
||||||
form.Set("client_secret", antigravityClientSecret)
|
|
||||||
form.Set("redirect_uri", redirectURI)
|
|
||||||
form.Set("grant_type", "authorization_code")
|
|
||||||
|
|
||||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
|
||||||
if errNewRequest != nil {
|
|
||||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
|
||||||
SetOAuthSessionError(state, "Failed to build token request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
log.Errorf("Failed to execute token request: %v", errDo)
|
|
||||||
SetOAuthSessionError(state, "Failed to exchange token")
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity token exchange close error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
if accessToken == "" {
|
||||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Error("antigravity: token exchange returned empty access token")
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp struct {
|
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
|
||||||
AccessToken string `json:"access_token"`
|
if errInfo != nil {
|
||||||
RefreshToken string `json:"refresh_token"`
|
log.Errorf("Failed to fetch user info: %v", errInfo)
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
}
|
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
|
||||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
email := ""
|
if email == "" {
|
||||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
log.Error("antigravity: user info returned empty email")
|
||||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||||
if errInfoReq != nil {
|
return
|
||||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
|
||||||
SetOAuthSessionError(state, "Failed to build user info request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
|
||||||
|
|
||||||
infoResp, errInfo := httpClient.Do(infoReq)
|
|
||||||
if errInfo != nil {
|
|
||||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
|
||||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := infoResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity user info close error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
|
|
||||||
var infoPayload struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
|
|
||||||
email = strings.TrimSpace(infoPayload.Email)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
|
||||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
projectID := ""
|
projectID := ""
|
||||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
if accessToken != "" {
|
||||||
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
|
||||||
if errProject != nil {
|
if errProject != nil {
|
||||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||||
} else {
|
} else {
|
||||||
@@ -1644,7 +1526,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
metadata["project_id"] = projectID
|
metadata["project_id"] = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := sanitizeAntigravityFileName(email)
|
fileName := antigravity.CredentialFileName(email)
|
||||||
label := strings.TrimSpace(email)
|
label := strings.TrimSpace(email)
|
||||||
if label == "" {
|
if label == "" {
|
||||||
label = "antigravity"
|
label = "antigravity"
|
||||||
@@ -2198,7 +2080,20 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
|||||||
finalProjectID := projectID
|
finalProjectID := projectID
|
||||||
if responseProjectID != "" {
|
if responseProjectID != "" {
|
||||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||||
|
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||||
|
strings.EqualFold(tierID, "FREE") ||
|
||||||
|
strings.EqualFold(tierID, "LEGACY")
|
||||||
|
|
||||||
|
if isFreeUser {
|
||||||
|
// For free users, use backend project ID for preview model access
|
||||||
|
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
|
||||||
|
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
|
||||||
|
finalProjectID = responseProjectID
|
||||||
|
} else {
|
||||||
|
// Pro users: keep requested project ID (original behavior)
|
||||||
|
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
finalProjectID = responseProjectID
|
finalProjectID = responseProjectID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
@@ -292,6 +293,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
s.registerManagementRoutes()
|
s.registerManagementRoutes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
|
||||||
|
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
|
||||||
|
kiroOAuthHandler.RegisterRoutes(engine)
|
||||||
|
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
|
||||||
|
|
||||||
if optionState.keepAliveEnabled {
|
if optionState.keepAliveEnabled {
|
||||||
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
||||||
}
|
}
|
||||||
@@ -630,6 +636,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||||
|
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||||
|
|
||||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||||
|
|||||||
344
internal/auth/antigravity/auth.go
Normal file
344
internal/auth/antigravity/auth.go
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenResponse represents OAuth token response from Google
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// userInfo represents Google user profile
|
||||||
|
type userInfo struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityAuth handles Antigravity OAuth authentication
|
||||||
|
type AntigravityAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAntigravityAuth creates a new Antigravity auth service.
|
||||||
|
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
|
||||||
|
if httpClient != nil {
|
||||||
|
return &AntigravityAuth{httpClient: httpClient}
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &config.Config{}
|
||||||
|
}
|
||||||
|
return &AntigravityAuth{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAuthURL generates the OAuth authorization URL.
|
||||||
|
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
|
||||||
|
}
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("access_type", "offline")
|
||||||
|
params.Set("client_id", ClientID)
|
||||||
|
params.Set("prompt", "consent")
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("scope", strings.Join(Scopes, " "))
|
||||||
|
params.Set("state", state)
|
||||||
|
return AuthEndpoint + "?" + params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
|
||||||
|
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("code", code)
|
||||||
|
data.Set("client_id", ClientID)
|
||||||
|
data.Set("client_secret", ClientSecret)
|
||||||
|
data.Set("redirect_uri", redirectURI)
|
||||||
|
data.Set("grant_type", "authorization_code")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
if errRead != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if body == "" {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token TokenResponse
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves user email from Google
|
||||||
|
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
accessToken = strings.TrimSpace(accessToken)
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: missing access token")
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if body == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
var info userInfo
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
email := strings.TrimSpace(info.Email)
|
||||||
|
if email == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: response missing email")
|
||||||
|
}
|
||||||
|
return email, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
|
||||||
|
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
loadReqBody := map[string]any{
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", APIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||||
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var loadResp map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract projectID from response
|
||||||
|
projectID := ""
|
||||||
|
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
if projectID == "" {
|
||||||
|
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||||
|
if id, okID := projectMap["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID == "" {
|
||||||
|
tierID := "legacy-tier"
|
||||||
|
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||||
|
for _, rawTier := range tiers {
|
||||||
|
tier, okTier := rawTier.(map[string]any)
|
||||||
|
if !okTier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||||
|
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||||
|
tierID = strings.TrimSpace(id)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
|
||||||
|
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||||
|
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
|
||||||
|
requestBody := map[string]any{
|
||||||
|
"tierId": tierID,
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(requestBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAttempts := 5
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||||
|
|
||||||
|
reqCtx := ctx
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = context.Background()
|
||||||
|
}
|
||||||
|
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
|
||||||
|
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if errRequest != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("create request: %w", errRequest)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", APIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||||
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
var data map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if done, okDone := data["done"].(bool); okDone && done {
|
||||||
|
projectID := ""
|
||||||
|
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||||
|
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if id, okID := projectValue["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
projectID = strings.TrimSpace(projectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID != "" {
|
||||||
|
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no project_id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if len(responsePreview) > 500 {
|
||||||
|
responsePreview = responsePreview[:500]
|
||||||
|
}
|
||||||
|
|
||||||
|
responseErr := responsePreview
|
||||||
|
if len(responseErr) > 200 {
|
||||||
|
responseErr = responseErr[:200]
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
34
internal/auth/antigravity/constants.go
Normal file
34
internal/auth/antigravity/constants.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
// OAuth client credentials for the Antigravity login flow, and the local
// port the OAuth callback server listens on.
const (
	ClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	ClientSecret = "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"
	CallbackPort = 51121
)

// Scopes lists the OAuth scopes requested during Antigravity authentication.
var Scopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/userinfo.email",
	"https://www.googleapis.com/auth/userinfo.profile",
	"https://www.googleapis.com/auth/cclog",
	"https://www.googleapis.com/auth/experimentsandconfigs",
}

// Google OAuth2 endpoints used for authorization, token exchange, and user
// info lookup.
const (
	TokenEndpoint    = "https://oauth2.googleapis.com/token"
	AuthEndpoint     = "https://accounts.google.com/o/oauth2/v2/auth"
	UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
)

// Cloud Code private API configuration used for Antigravity requests.
const (
	APIEndpoint    = "https://cloudcode-pa.googleapis.com"
	APIVersion     = "v1internal"
	APIUserAgent   = "google-api-nodejs-client/9.15.1"
	APIClient      = "google-cloud-sdk vscode_cloudshelleditor/0.1"
	ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)
|
||||||
16
internal/auth/antigravity/filename.go
Normal file
16
internal/auth/antigravity/filename.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Antigravity
// credentials. A non-empty email is embedded in the name so multiple
// accounts do not collide; an empty (or whitespace-only) email yields the
// generic default "antigravity.json".
func CredentialFileName(email string) string {
	trimmed := strings.TrimSpace(email)
	if trimmed == "" {
		return "antigravity.json"
	}
	return "antigravity-" + trimmed + ".json"
}
|
||||||
@@ -18,11 +18,12 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for Claude/Anthropic
|
||||||
const (
|
const (
|
||||||
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
AuthURL = "https://claude.ai/oauth/authorize"
|
||||||
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||||
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
redirectURI = "http://localhost:54545/callback"
|
RedirectURI = "http://localhost:54545/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||||
@@ -82,16 +83,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
|||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"code": {"true"},
|
"code": {"true"},
|
||||||
"client_id": {anthropicClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"org:create_api_key user:profile user:inference"},
|
"scope": {"org:create_api_key user:profile user:inference"},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
"code_challenge_method": {"S256"},
|
"code_challenge_method": {"S256"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||||
return authURL, state, nil
|
return authURL, state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,8 +138,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
|||||||
"code": newCode,
|
"code": newCode,
|
||||||
"state": state,
|
"state": state,
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"client_id": anthropicClientID,
|
"client_id": ClientID,
|
||||||
"redirect_uri": redirectURI,
|
"redirect_uri": RedirectURI,
|
||||||
"code_verifier": pkceCodes.CodeVerifier,
|
"code_verifier": pkceCodes.CodeVerifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +155,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
|||||||
|
|
||||||
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -221,7 +222,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
|||||||
}
|
}
|
||||||
|
|
||||||
reqBody := map[string]interface{}{
|
reqBody := map[string]interface{}{
|
||||||
"client_id": anthropicClientID,
|
"client_id": ClientID,
|
||||||
"grant_type": "refresh_token",
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refreshToken,
|
"refresh_token": refreshToken,
|
||||||
}
|
}
|
||||||
@@ -231,7 +232,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
|||||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
46
internal/auth/codex/filename.go
Normal file
46
internal/auth/codex/filename.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Codex OAuth
// credentials. When planType is available (e.g. "plus", "team") it is
// appended after the email to disambiguate subscriptions; "team" plans also
// embed the hashed account ID before the email. includeProviderPrefix
// prepends "codex" to the name.
//
// NOTE(review): with includeProviderPrefix == false the result begins with
// a literal "-" (e.g. "-user@x.json"). Preserved as-is since existing
// credential files on disk may use this form — confirm it is intentional.
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
	mail := strings.TrimSpace(email)
	plan := normalizePlanTypeForFilename(planType)

	prefix := ""
	if includeProviderPrefix {
		prefix = "codex"
	}

	switch {
	case plan == "":
		return fmt.Sprintf("%s-%s.json", prefix, mail)
	case plan == "team":
		return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, mail, plan)
	default:
		return fmt.Sprintf("%s-%s-%s.json", prefix, mail, plan)
	}
}

// normalizePlanTypeForFilename lowercases planType and collapses every run
// of non-alphanumeric characters into a single "-", returning "" when no
// alphanumeric content remains.
func normalizePlanTypeForFilename(planType string) string {
	fields := strings.FieldsFunc(strings.TrimSpace(planType), func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsDigit(r)
	})
	if len(fields) == 0 {
		return ""
	}
	for i := range fields {
		fields[i] = strings.ToLower(fields[i])
	}
	return strings.Join(fields, "-")
}
|
||||||
@@ -19,11 +19,12 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for OpenAI Codex
|
||||||
const (
|
const (
|
||||||
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
|
AuthURL = "https://auth.openai.com/oauth/authorize"
|
||||||
openaiTokenURL = "https://auth.openai.com/oauth/token"
|
TokenURL = "https://auth.openai.com/oauth/token"
|
||||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
redirectURI = "http://localhost:1455/auth/callback"
|
RedirectURI = "http://localhost:1455/auth/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
||||||
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"openid email profile offline_access"},
|
"scope": {"openid email profile offline_access"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
"codex_cli_simplified_flow": {"true"},
|
"codex_cli_simplified_flow": {"true"},
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||||
return authURL, nil
|
return authURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
|
|||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"refresh_token": {refreshToken},
|
"refresh_token": {refreshToken},
|
||||||
"scope": {"openid profile email"},
|
"scope": {"openid profile email"},
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,19 +28,19 @@ import (
|
|||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for Gemini
|
||||||
const (
|
const (
|
||||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
geminiOauthClientSecret = "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"
|
ClientSecret = "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"
|
||||||
geminiDefaultCallbackPort = 8085
|
DefaultCallbackPort = 8085
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// OAuth scopes for Gemini authentication
|
||||||
geminiOauthScopes = []string{
|
var Scopes = []string{
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
||||||
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
||||||
@@ -74,7 +74,7 @@ func NewGeminiAuth() *GeminiAuth {
|
|||||||
// - *http.Client: An HTTP client configured with authentication
|
// - *http.Client: An HTTP client configured with authentication
|
||||||
// - error: An error if the client configuration fails, nil otherwise
|
// - error: An error if the client configuration fails, nil otherwise
|
||||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||||
callbackPort := geminiDefaultCallbackPort
|
callbackPort := DefaultCallbackPort
|
||||||
if opts != nil && opts.CallbackPort > 0 {
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
}
|
}
|
||||||
@@ -112,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: geminiOauthClientID,
|
ClientID: ClientID,
|
||||||
ClientSecret: geminiOauthClientSecret,
|
ClientSecret: ClientSecret,
|
||||||
RedirectURL: callbackURL, // This will be used by the local server.
|
RedirectURL: callbackURL, // This will be used by the local server.
|
||||||
Scopes: geminiOauthScopes,
|
Scopes: Scopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = geminiOauthClientID
|
ifToken["client_id"] = ClientID
|
||||||
ifToken["client_secret"] = geminiOauthClientSecret
|
ifToken["client_secret"] = ClientSecret
|
||||||
ifToken["scopes"] = geminiOauthScopes
|
ifToken["scopes"] = Scopes
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := GeminiTokenStorage{
|
ts := GeminiTokenStorage{
|
||||||
@@ -226,7 +226,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||||
// - error: An error if the token acquisition fails, nil otherwise
|
// - error: An error if the token acquisition fails, nil otherwise
|
||||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||||
callbackPort := geminiDefaultCallbackPort
|
callbackPort := DefaultCallbackPort
|
||||||
if opts != nil && opts.CallbackPort > 0 {
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ package kiro
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||||
@@ -85,6 +87,87 @@ type KiroModel struct {
|
|||||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||||
|
|
||||||
|
// Default retry configuration for file reading
|
||||||
|
const (
|
||||||
|
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||||
|
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||||
|
)
|
||||||
|
|
||||||
|
// isTransientFileError checks if the error is a transient file access error
|
||||||
|
// that may be resolved by retrying (e.g., file locked by another process on Windows).
|
||||||
|
func isTransientFileError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for OS-level file access errors (Windows sharing violation, etc.)
|
||||||
|
var pathErr *os.PathError
|
||||||
|
if errors.As(err, &pathErr) {
|
||||||
|
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
|
||||||
|
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
|
||||||
|
errStr := pathErr.Err.Error()
|
||||||
|
if strings.Contains(errStr, "being used by another process") ||
|
||||||
|
strings.Contains(errStr, "sharing violation") ||
|
||||||
|
strings.Contains(errStr, "lock violation") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error message for common transient patterns
|
||||||
|
errMsg := strings.ToLower(err.Error())
|
||||||
|
transientPatterns := []string{
|
||||||
|
"being used by another process",
|
||||||
|
"sharing violation",
|
||||||
|
"lock violation",
|
||||||
|
"access is denied",
|
||||||
|
"unexpected end of json",
|
||||||
|
"unexpected eof",
|
||||||
|
}
|
||||||
|
for _, pattern := range transientPatterns {
|
||||||
|
if strings.Contains(errMsg, pattern) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
|
||||||
|
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
|
||||||
|
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
|
||||||
|
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
|
||||||
|
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
|
||||||
|
if maxAttempts <= 0 {
|
||||||
|
maxAttempts = defaultTokenReadMaxAttempts
|
||||||
|
}
|
||||||
|
if baseDelay <= 0 {
|
||||||
|
baseDelay = defaultTokenReadBaseDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
token, err := LoadKiroIDEToken()
|
||||||
|
if err == nil {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
|
||||||
|
// Only retry for transient errors
|
||||||
|
if !isTransientFileError(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exponential backoff: delay * 2^attempt, capped at 500ms
|
||||||
|
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||||
|
if delay > 500*time.Millisecond {
|
||||||
|
delay = 500 * time.Millisecond
|
||||||
|
}
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||||
homeDir, err := os.UserHomeDir()
|
homeDir, err := os.UserHomeDir()
|
||||||
@@ -107,6 +190,9 @@ func LoadKiroIDEToken() (*KiroTokenData, error) {
|
|||||||
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||||
|
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||||
|
|
||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +222,9 @@ func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
|||||||
return nil, fmt.Errorf("access token is empty in token file")
|
return nil, fmt.Errorf("access token is empty in token file")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||||
|
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||||
|
|
||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -280,6 +280,11 @@ func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorag
|
|||||||
AuthMethod: tokenData.AuthMethod,
|
AuthMethod: tokenData.AuthMethod,
|
||||||
Provider: tokenData.Provider,
|
Provider: tokenData.Provider,
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
LastRefresh: time.Now().Format(time.RFC3339),
|
||||||
|
ClientID: tokenData.ClientID,
|
||||||
|
ClientSecret: tokenData.ClientSecret,
|
||||||
|
Region: tokenData.Region,
|
||||||
|
StartURL: tokenData.StartURL,
|
||||||
|
Email: tokenData.Email,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,4 +316,19 @@ func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *Kiro
|
|||||||
storage.AuthMethod = tokenData.AuthMethod
|
storage.AuthMethod = tokenData.AuthMethod
|
||||||
storage.Provider = tokenData.Provider
|
storage.Provider = tokenData.Provider
|
||||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||||
|
if tokenData.ClientID != "" {
|
||||||
|
storage.ClientID = tokenData.ClientID
|
||||||
|
}
|
||||||
|
if tokenData.ClientSecret != "" {
|
||||||
|
storage.ClientSecret = tokenData.ClientSecret
|
||||||
|
}
|
||||||
|
if tokenData.Region != "" {
|
||||||
|
storage.Region = tokenData.Region
|
||||||
|
}
|
||||||
|
if tokenData.StartURL != "" {
|
||||||
|
storage.StartURL = tokenData.StartURL
|
||||||
|
}
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
storage.Email = tokenData.Email
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
224
internal/auth/kiro/background_refresh.go
Normal file
224
internal/auth/kiro/background_refresh.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Token is the in-memory representation of a Kiro credential tracked by the
// background refresher. ID identifies the token (used in log messages and
// passed to the refresh callback); the remaining fields mirror the persisted
// token data needed to perform a refresh.
type Token struct {
	ID           string
	AccessToken  string
	RefreshToken string
	ExpiresAt    time.Time
	LastVerified time.Time // when this token was last successfully refreshed
	ClientID     string
	ClientSecret string
	AuthMethod   string // "idc", "builder-id", or anything else for social OAuth
	Provider     string
	StartURL     string
	Region       string
}

// TokenRepository abstracts the storage backend the refresher pulls work from
// and writes refreshed credentials back to.
type TokenRepository interface {
	// FindOldestUnverified returns up to limit tokens, prioritized by how
	// long ago they were last verified.
	FindOldestUnverified(limit int) []*Token
	// UpdateToken persists a token after a successful refresh.
	UpdateToken(token *Token) error
}

// RefresherOption configures a BackgroundRefresher at construction time.
type RefresherOption func(*BackgroundRefresher)

// WithInterval overrides the delay between refresh batches (default 1 minute).
func WithInterval(interval time.Duration) RefresherOption {
	return func(r *BackgroundRefresher) {
		r.interval = interval
	}
}

// WithBatchSize overrides how many tokens are fetched per batch (default 50).
func WithBatchSize(size int) RefresherOption {
	return func(r *BackgroundRefresher) {
		r.batchSize = size
	}
}

// WithConcurrency overrides how many refreshes may run in parallel (default 10).
func WithConcurrency(concurrency int) RefresherOption {
	return func(r *BackgroundRefresher) {
		r.concurrency = concurrency
	}
}

// BackgroundRefresher periodically refreshes Kiro tokens in a background
// goroutine started by Start and stopped via Stop or context cancellation.
type BackgroundRefresher struct {
	interval    time.Duration // delay between refresh batches
	batchSize   int           // max tokens pulled per batch
	concurrency int           // max refreshes in flight at once
	tokenRepo   TokenRepository
	stopCh      chan struct{}
	wg          sync.WaitGroup
	oauth       *KiroOAuth     // social-login refresh client; nil until WithConfig is applied
	ssoClient   *SSOOIDCClient // IdC/builder-id refresh client; nil until WithConfig is applied
	callbackMu  sync.RWMutex   // guards concurrent access to onTokenRefreshed
	onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // invoked after each successful refresh
}

// NewBackgroundRefresher builds a refresher over repo with default interval,
// batch size, and concurrency, then applies opts in order. Pass WithConfig to
// enable the OAuth/SSO clients; without it the clients remain nil.
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
	r := &BackgroundRefresher{
		interval:    time.Minute,
		batchSize:   50,
		concurrency: 10,
		tokenRepo:   repo,
		stopCh:      make(chan struct{}),
		oauth:       nil, // Lazy init - will be set when config available
		ssoClient:   nil, // Lazy init - will be set when config available
	}
	for _, opt := range opts {
		opt(r)
	}
	return r
}
|
||||||
|
|
||||||
|
// WithConfig sets the configuration for OAuth and SSO clients.
func WithConfig(cfg *config.Config) RefresherOption {
	return func(r *BackgroundRefresher) {
		r.oauth = NewKiroOAuth(cfg)
		r.ssoClient = NewSSOOIDCClient(cfg)
	}
}

// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
// The callback receives the token ID (filename) and the new token data.
// This allows external components (e.g., Watcher) to be notified of token updates.
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
	return func(r *BackgroundRefresher) {
		r.callbackMu.Lock()
		r.onTokenRefreshed = callback
		r.callbackMu.Unlock()
	}
}

// Start launches the background refresh loop in its own goroutine. One batch
// runs immediately, then one per interval tick. The loop exits when ctx is
// cancelled or Stop is called; Stop waits for it via the WaitGroup.
func (r *BackgroundRefresher) Start(ctx context.Context) {
	r.wg.Add(1)
	go func() {
		defer r.wg.Done()
		ticker := time.NewTicker(r.interval)
		defer ticker.Stop()

		// Run the first batch right away instead of waiting a full interval.
		r.refreshBatch(ctx)

		for {
			select {
			case <-ctx.Done():
				return
			case <-r.stopCh:
				return
			case <-ticker.C:
				r.refreshBatch(ctx)
			}
		}
	}()
}
|
||||||
|
|
||||||
|
// Stop signals the refresh loop to exit and blocks until it has finished.
// NOTE(review): must be called at most once — closing stopCh a second time
// panics.
func (r *BackgroundRefresher) Stop() {
	close(r.stopCh)
	r.wg.Wait()
}

// refreshBatch refreshes one batch of the stalest tokens. Launches are
// staggered 100ms apart, bounded by a weighted semaphore of size
// r.concurrency, and the method blocks until every in-flight refresh has
// completed.
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
	tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
	if len(tokens) == 0 {
		return
	}

	sem := semaphore.NewWeighted(int64(r.concurrency))
	var wg sync.WaitGroup

	for i, token := range tokens {
		if i > 0 {
			// Stagger launches to avoid a thundering herd against the
			// upstream auth endpoints; still honor shutdown signals.
			select {
			case <-ctx.Done():
				return
			case <-r.stopCh:
				return
			case <-time.After(100 * time.Millisecond):
			}
		}

		// Acquire fails only when ctx is cancelled, so abandoning the rest
		// of the batch here is the desired behavior.
		if err := sem.Acquire(ctx, 1); err != nil {
			return
		}

		wg.Add(1)
		go func(t *Token) {
			defer wg.Done()
			defer sem.Release(1)
			r.refreshSingle(ctx, t)
		}(token)
	}

	wg.Wait()
}
|
||||||
|
|
||||||
|
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||||
|
var newTokenData *KiroTokenData
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch token.AuthMethod {
|
||||||
|
case "idc":
|
||||||
|
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
|
||||||
|
ctx,
|
||||||
|
token.ClientID,
|
||||||
|
token.ClientSecret,
|
||||||
|
token.RefreshToken,
|
||||||
|
token.Region,
|
||||||
|
token.StartURL,
|
||||||
|
)
|
||||||
|
case "builder-id":
|
||||||
|
newTokenData, err = r.ssoClient.RefreshToken(
|
||||||
|
ctx,
|
||||||
|
token.ClientID,
|
||||||
|
token.ClientSecret,
|
||||||
|
token.RefreshToken,
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to refresh token %s: %v", token.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token.AccessToken = newTokenData.AccessToken
|
||||||
|
token.RefreshToken = newTokenData.RefreshToken
|
||||||
|
token.LastVerified = time.Now()
|
||||||
|
|
||||||
|
if newTokenData.ExpiresAt != "" {
|
||||||
|
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
||||||
|
token.ExpiresAt = expTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||||
|
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
|
||||||
|
r.callbackMu.RLock()
|
||||||
|
callback := r.onTokenRefreshed
|
||||||
|
r.callbackMu.RUnlock()
|
||||||
|
|
||||||
|
if callback != nil {
|
||||||
|
// 使用 defer recover 隔离回调 panic,防止崩溃整个进程
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
|
||||||
|
callback(token.ID, newTokenData)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
112
internal/auth/kiro/cooldown.go
Normal file
112
internal/auth/kiro/cooldown.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reasons recorded alongside a cooldown window, and the backoff bounds used
// when computing cooldown durations.
const (
	CooldownReason429            = "rate_limit_exceeded"
	CooldownReasonSuspended      = "account_suspended"
	CooldownReasonQuotaExhausted = "quota_exhausted"

	// DefaultShortCooldown is the base backoff for transient failures;
	// MaxShortCooldown caps the exponential backoff; LongCooldown is a
	// full-day window for conditions expected to persist.
	DefaultShortCooldown = 1 * time.Minute
	MaxShortCooldown     = 5 * time.Minute
	LongCooldown         = 24 * time.Hour
)
|
||||||
|
|
||||||
|
type CooldownManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
cooldowns map[string]time.Time
|
||||||
|
reasons map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCooldownManager() *CooldownManager {
|
||||||
|
return &CooldownManager{
|
||||||
|
cooldowns: make(map[string]time.Time),
|
||||||
|
reasons: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
|
||||||
|
cm.mu.Lock()
|
||||||
|
defer cm.mu.Unlock()
|
||||||
|
cm.cooldowns[tokenKey] = time.Now().Add(duration)
|
||||||
|
cm.reasons[tokenKey] = reason
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
|
||||||
|
cm.mu.RLock()
|
||||||
|
defer cm.mu.RUnlock()
|
||||||
|
endTime, exists := cm.cooldowns[tokenKey]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().Before(endTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
|
||||||
|
cm.mu.RLock()
|
||||||
|
defer cm.mu.RUnlock()
|
||||||
|
endTime, exists := cm.cooldowns[tokenKey]
|
||||||
|
if !exists {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
remaining := time.Until(endTime)
|
||||||
|
if remaining < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
|
||||||
|
cm.mu.RLock()
|
||||||
|
defer cm.mu.RUnlock()
|
||||||
|
return cm.reasons[tokenKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
|
||||||
|
cm.mu.Lock()
|
||||||
|
defer cm.mu.Unlock()
|
||||||
|
delete(cm.cooldowns, tokenKey)
|
||||||
|
delete(cm.reasons, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) CleanupExpired() {
|
||||||
|
cm.mu.Lock()
|
||||||
|
defer cm.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
for tokenKey, endTime := range cm.cooldowns {
|
||||||
|
if now.After(endTime) {
|
||||||
|
delete(cm.cooldowns, tokenKey)
|
||||||
|
delete(cm.reasons, tokenKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
cm.CleanupExpired()
|
||||||
|
case <-stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateCooldownFor429(retryCount int) time.Duration {
|
||||||
|
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
|
||||||
|
if duration > MaxShortCooldown {
|
||||||
|
return MaxShortCooldown
|
||||||
|
}
|
||||||
|
return duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateCooldownUntilNextDay() time.Duration {
|
||||||
|
now := time.Now()
|
||||||
|
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||||
|
return time.Until(nextDay)
|
||||||
|
}
|
||||||
240
internal/auth/kiro/cooldown_test.go
Normal file
240
internal/auth/kiro/cooldown_test.go
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewCooldownManager verifies the constructor initializes both maps.
func TestNewCooldownManager(t *testing.T) {
	cm := NewCooldownManager()
	if cm == nil {
		t.Fatal("expected non-nil CooldownManager")
	}
	if cm.cooldowns == nil {
		t.Error("expected non-nil cooldowns map")
	}
	if cm.reasons == nil {
		t.Error("expected non-nil reasons map")
	}
}

// TestSetCooldown verifies a set cooldown is active and its reason recorded.
func TestSetCooldown(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)

	if !cm.IsInCooldown("token1") {
		t.Error("expected token to be in cooldown")
	}
	if cm.GetCooldownReason("token1") != CooldownReason429 {
		t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
	}
}

// TestIsInCooldown_NotSet: an unknown key is never in cooldown.
func TestIsInCooldown_NotSet(t *testing.T) {
	cm := NewCooldownManager()
	if cm.IsInCooldown("nonexistent") {
		t.Error("expected non-existent token to not be in cooldown")
	}
}

// TestIsInCooldown_Expired: an elapsed window reads as not-in-cooldown.
func TestIsInCooldown_Expired(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)

	time.Sleep(10 * time.Millisecond)

	if cm.IsInCooldown("token1") {
		t.Error("expected expired cooldown to return false")
	}
}

// TestGetRemainingCooldown: remaining time is positive and bounded by the
// duration that was set.
func TestGetRemainingCooldown(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Second, CooldownReason429)

	remaining := cm.GetRemainingCooldown("token1")
	if remaining <= 0 || remaining > 1*time.Second {
		t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
	}
}

// TestGetRemainingCooldown_NotSet: unknown key yields zero.
func TestGetRemainingCooldown_NotSet(t *testing.T) {
	cm := NewCooldownManager()
	remaining := cm.GetRemainingCooldown("nonexistent")
	if remaining != 0 {
		t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
	}
}

// TestGetRemainingCooldown_Expired: elapsed window yields zero, not negative.
func TestGetRemainingCooldown_Expired(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)

	time.Sleep(10 * time.Millisecond)

	remaining := cm.GetRemainingCooldown("token1")
	if remaining != 0 {
		t.Errorf("expected 0 remaining for expired, got %v", remaining)
	}
}

// TestGetCooldownReason verifies the stored reason round-trips.
func TestGetCooldownReason(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)

	reason := cm.GetCooldownReason("token1")
	if reason != CooldownReasonSuspended {
		t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
	}
}

// TestGetCooldownReason_NotSet: unknown key yields the empty string.
func TestGetCooldownReason_NotSet(t *testing.T) {
	cm := NewCooldownManager()
	reason := cm.GetCooldownReason("nonexistent")
	if reason != "" {
		t.Errorf("expected empty reason for non-existent, got %s", reason)
	}
}

// TestClearCooldown verifies both the window and the reason are removed.
func TestClearCooldown(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
	cm.ClearCooldown("token1")

	if cm.IsInCooldown("token1") {
		t.Error("expected cooldown to be cleared")
	}
	if cm.GetCooldownReason("token1") != "" {
		t.Error("expected reason to be cleared")
	}
}

// TestClearCooldown_NonExistent: clearing an unknown key must not panic.
func TestClearCooldown_NonExistent(t *testing.T) {
	cm := NewCooldownManager()
	cm.ClearCooldown("nonexistent")
}

// TestCleanupExpired: expired entries are purged, active ones survive.
func TestCleanupExpired(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
	cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
	cm.SetCooldown("active", 1*time.Hour, CooldownReason429)

	time.Sleep(10 * time.Millisecond)
	cm.CleanupExpired()

	if cm.GetCooldownReason("expired1") != "" {
		t.Error("expected expired1 to be cleaned up")
	}
	if cm.GetCooldownReason("expired2") != "" {
		t.Error("expected expired2 to be cleaned up")
	}
	if cm.GetCooldownReason("active") != CooldownReason429 {
		t.Error("expected active to remain")
	}
}

// TestCalculateCooldownFor429_FirstRetry: retry 0 gets the base cooldown.
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
	duration := CalculateCooldownFor429(0)
	if duration != DefaultShortCooldown {
		t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
	}
}

// TestCalculateCooldownFor429_Exponential: backoff grows with retry count.
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
	d1 := CalculateCooldownFor429(1)
	d2 := CalculateCooldownFor429(2)

	if d2 <= d1 {
		t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
	}
}

// TestCalculateCooldownFor429_MaxCap: large retry counts hit the cap.
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
	duration := CalculateCooldownFor429(10)
	if duration > MaxShortCooldown {
		t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
	}
}

// TestCalculateCooldownUntilNextDay: result is within (0, 24h].
func TestCalculateCooldownUntilNextDay(t *testing.T) {
	duration := CalculateCooldownUntilNextDay()
	if duration <= 0 || duration > 24*time.Hour {
		t.Errorf("expected duration between 0 and 24h, got %v", duration)
	}
}

// TestCooldownManager_ConcurrentAccess hammers every method from many
// goroutines; meaningful only under `go test -race`.
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
	cm := NewCooldownManager()
	const numGoroutines = 50
	const numOperations = 100

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	for i := 0; i < numGoroutines; i++ {
		go func(id int) {
			defer wg.Done()
			// 10 distinct keys shared across goroutines to force contention.
			tokenKey := "token" + string(rune('a'+id%10))
			for j := 0; j < numOperations; j++ {
				switch j % 6 {
				case 0:
					cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
				case 1:
					cm.IsInCooldown(tokenKey)
				case 2:
					cm.GetRemainingCooldown(tokenKey)
				case 3:
					cm.GetCooldownReason(tokenKey)
				case 4:
					cm.ClearCooldown(tokenKey)
				case 5:
					cm.CleanupExpired()
				}
			}
		}(i)
	}

	wg.Wait()
}

// TestCooldownReasonConstants pins the exact reason strings (they may be
// persisted or compared externally).
func TestCooldownReasonConstants(t *testing.T) {
	if CooldownReason429 != "rate_limit_exceeded" {
		t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
	}
	if CooldownReasonSuspended != "account_suspended" {
		t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
	}
	if CooldownReasonQuotaExhausted != "quota_exhausted" {
		t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
	}
}

// TestDefaultConstants pins the duration constants.
func TestDefaultConstants(t *testing.T) {
	if DefaultShortCooldown != 1*time.Minute {
		t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
	}
	if MaxShortCooldown != 5*time.Minute {
		t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
	}
	if LongCooldown != 24*time.Hour {
		t.Errorf("unexpected LongCooldown: %v", LongCooldown)
	}
}

// TestSetCooldown_OverwritesPrevious: a second SetCooldown replaces both the
// window and the reason.
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
	cm := NewCooldownManager()
	cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
	cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)

	reason := cm.GetCooldownReason("token1")
	if reason != CooldownReasonSuspended {
		t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
	}

	remaining := cm.GetRemainingCooldown("token1")
	if remaining > 1*time.Minute {
		t.Errorf("expected remaining <= 1 minute, got %v", remaining)
	}
}
|
||||||
197
internal/auth/kiro/fingerprint.go
Normal file
197
internal/auth/kiro/fingerprint.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Fingerprint holds a per-token, multi-dimensional client fingerprint
// (SDK/OS/Node/Kiro versions plus browser-style attributes) used to populate
// request headers and User-Agent strings.
type Fingerprint struct {
	SDKVersion          string // 1.0.20-1.0.27
	OSType              string // darwin/windows/linux
	OSVersion           string // 10.0.22621
	NodeVersion         string // 18.x/20.x/22.x
	KiroVersion         string // 0.3.x-0.8.x
	KiroHash            string // SHA256
	AcceptLanguage      string
	ScreenResolution    string // 1920x1080
	ColorDepth          int    // 24
	HardwareConcurrency int    // number of CPU cores
	TimezoneOffset      int
}

// FingerprintManager creates and caches one stable Fingerprint per token key.
// All methods are safe for concurrent use.
type FingerprintManager struct {
	mu           sync.RWMutex
	fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
	rng          *rand.Rand              // not goroutine-safe on its own; only used while mu is write-locked
}

// Candidate pools the fingerprint generator draws from.
var (
	sdkVersions = []string{
		"1.0.20", "1.0.21", "1.0.22", "1.0.23",
		"1.0.24", "1.0.25", "1.0.26", "1.0.27",
	}
	osTypes = []string{"darwin", "windows", "linux"}
	// osVersions keys must cover every entry of osTypes so the generated
	// OS version always matches the chosen OS type.
	osVersions = map[string][]string{
		"darwin":  {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
		"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
		"linux":   {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
	}
	nodeVersions = []string{
		"18.17.0", "18.18.0", "18.19.0", "18.20.0",
		"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
		"22.0.0", "22.1.0", "22.2.0", "22.3.0",
	}
	kiroVersions = []string{
		"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
		"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
	}
	acceptLanguages = []string{
		"en-US,en;q=0.9",
		"en-GB,en;q=0.9",
		"zh-CN,zh;q=0.9,en;q=0.8",
		"zh-TW,zh;q=0.9,en;q=0.8",
		"ja-JP,ja;q=0.9,en;q=0.8",
		"ko-KR,ko;q=0.9,en;q=0.8",
		"de-DE,de;q=0.9,en;q=0.8",
		"fr-FR,fr;q=0.9,en;q=0.8",
	}
	screenResolutions = []string{
		"1920x1080", "2560x1440", "3840x2160",
		"1366x768", "1440x900", "1680x1050",
		"2560x1600", "3440x1440",
	}
	colorDepths           = []int{24, 32}
	hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
	timezoneOffsets       = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
)

// NewFingerprintManager creates a fingerprint manager with a time-seeded
// random source.
func NewFingerprintManager() *FingerprintManager {
	return &FingerprintManager{
		fingerprints: make(map[string]*Fingerprint),
		rng:          rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}

// GetFingerprint returns the fingerprint associated with tokenKey, generating
// and caching a new one on first use. Uses double-checked locking: a fast
// read-locked lookup, then a write-locked re-check before generating, so a
// concurrent caller that generated first wins.
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
	fm.mu.RLock()
	if fp, exists := fm.fingerprints[tokenKey]; exists {
		fm.mu.RUnlock()
		return fp
	}
	fm.mu.RUnlock()

	fm.mu.Lock()
	defer fm.mu.Unlock()

	// Re-check: another goroutine may have generated between the locks.
	if fp, exists := fm.fingerprints[tokenKey]; exists {
		return fp
	}

	fp := fm.generateFingerprint(tokenKey)
	fm.fingerprints[tokenKey] = fp
	return fp
}

// generateFingerprint builds a new random fingerprint for tokenKey.
// Caller must hold fm.mu (write lock) because fm.rng is not goroutine-safe.
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
	// Pick the OS first so the OS version can be drawn from the matching pool.
	osType := fm.randomChoice(osTypes)
	osVersion := fm.randomChoice(osVersions[osType])
	kiroVersion := fm.randomChoice(kiroVersions)

	fp := &Fingerprint{
		SDKVersion:          fm.randomChoice(sdkVersions),
		OSType:              osType,
		OSVersion:           osVersion,
		NodeVersion:         fm.randomChoice(nodeVersions),
		KiroVersion:         kiroVersion,
		AcceptLanguage:      fm.randomChoice(acceptLanguages),
		ScreenResolution:    fm.randomChoice(screenResolutions),
		ColorDepth:          fm.randomIntChoice(colorDepths),
		HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
		TimezoneOffset:      fm.randomIntChoice(timezoneOffsets),
	}

	fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
	return fp
}

// generateKiroHash derives a hex-encoded SHA-256 hash from the token key,
// version, OS type, and the current nanosecond timestamp (so each generated
// fingerprint gets a unique hash).
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
	data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
	hash := sha256.Sum256([]byte(data))
	return hex.EncodeToString(hash[:])
}

// randomChoice returns a uniformly random element of choices.
// Caller must hold fm.mu (write lock).
func (fm *FingerprintManager) randomChoice(choices []string) string {
	return choices[fm.rng.Intn(len(choices))]
}

// randomIntChoice returns a uniformly random element of choices.
// Caller must hold fm.mu (write lock).
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
	return choices[fm.rng.Intn(len(choices))]
}

// ApplyToRequest copies the fingerprint fields into the request headers.
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
	req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
	req.Header.Set("X-Kiro-OS-Type", fp.OSType)
	req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
	req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
	req.Header.Set("X-Kiro-Version", fp.KiroVersion)
	req.Header.Set("X-Kiro-Hash", fp.KiroHash)
	req.Header.Set("Accept-Language", fp.AcceptLanguage)
	req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
	req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
	req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
	req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
}

// RemoveFingerprint drops the cached fingerprint for tokenKey, if any.
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
	fm.mu.Lock()
	defer fm.mu.Unlock()
	delete(fm.fingerprints, tokenKey)
}

// Count returns the number of fingerprints currently cached.
func (fm *FingerprintManager) Count() int {
	fm.mu.RLock()
	defer fm.mu.RUnlock()
	return len(fm.fingerprints)
}

// BuildUserAgent builds a Kiro-IDE-style User-Agent string.
// Format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildUserAgent() string {
	return fmt.Sprintf(
		"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
		fp.SDKVersion,
		fp.OSType,
		fp.OSVersion,
		fp.NodeVersion,
		fp.SDKVersion,
		fp.KiroVersion,
		fp.KiroHash,
	)
}

// BuildAmzUserAgent builds the X-Amz-User-Agent string.
// Format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildAmzUserAgent() string {
	return fmt.Sprintf(
		"aws-sdk-js/%s KiroIDE-%s-%s",
		fp.SDKVersion,
		fp.KiroVersion,
		fp.KiroHash,
	)
}
|
||||||
227
internal/auth/kiro/fingerprint_test.go
Normal file
227
internal/auth/kiro/fingerprint_test.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewFingerprintManager verifies the constructor initializes its map
// and random source.
func TestNewFingerprintManager(t *testing.T) {
	fm := NewFingerprintManager()
	if fm == nil {
		t.Fatal("expected non-nil FingerprintManager")
	}
	if fm.fingerprints == nil {
		t.Error("expected non-nil fingerprints map")
	}
	if fm.rng == nil {
		t.Error("expected non-nil rng")
	}
}

// TestGetFingerprint_NewToken: a freshly generated fingerprint has every
// field populated.
func TestGetFingerprint_NewToken(t *testing.T) {
	fm := NewFingerprintManager()
	fp := fm.GetFingerprint("token1")

	if fp == nil {
		t.Fatal("expected non-nil Fingerprint")
	}
	if fp.SDKVersion == "" {
		t.Error("expected non-empty SDKVersion")
	}
	if fp.OSType == "" {
		t.Error("expected non-empty OSType")
	}
	if fp.OSVersion == "" {
		t.Error("expected non-empty OSVersion")
	}
	if fp.NodeVersion == "" {
		t.Error("expected non-empty NodeVersion")
	}
	if fp.KiroVersion == "" {
		t.Error("expected non-empty KiroVersion")
	}
	if fp.KiroHash == "" {
		t.Error("expected non-empty KiroHash")
	}
	if fp.AcceptLanguage == "" {
		t.Error("expected non-empty AcceptLanguage")
	}
	if fp.ScreenResolution == "" {
		t.Error("expected non-empty ScreenResolution")
	}
	if fp.ColorDepth == 0 {
		t.Error("expected non-zero ColorDepth")
	}
	if fp.HardwareConcurrency == 0 {
		t.Error("expected non-zero HardwareConcurrency")
	}
}

// TestGetFingerprint_SameTokenReturnsSameFingerprint: the cached pointer is
// returned on repeat lookups.
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
	fm := NewFingerprintManager()
	fp1 := fm.GetFingerprint("token1")
	fp2 := fm.GetFingerprint("token1")

	if fp1 != fp2 {
		t.Error("expected same fingerprint for same token")
	}
}

// TestGetFingerprint_DifferentTokens: distinct keys get distinct instances
// (pointer inequality).
func TestGetFingerprint_DifferentTokens(t *testing.T) {
	fm := NewFingerprintManager()
	fp1 := fm.GetFingerprint("token1")
	fp2 := fm.GetFingerprint("token2")

	if fp1 == fp2 {
		t.Error("expected different fingerprints for different tokens")
	}
}

// TestRemoveFingerprint: removal shrinks the cache.
func TestRemoveFingerprint(t *testing.T) {
	fm := NewFingerprintManager()
	fm.GetFingerprint("token1")
	if fm.Count() != 1 {
		t.Fatalf("expected count 1, got %d", fm.Count())
	}

	fm.RemoveFingerprint("token1")
	if fm.Count() != 0 {
		t.Errorf("expected count 0, got %d", fm.Count())
	}
}

// TestRemoveFingerprint_NonExistent: removing an unknown key is a no-op.
func TestRemoveFingerprint_NonExistent(t *testing.T) {
	fm := NewFingerprintManager()
	fm.RemoveFingerprint("nonexistent")
	if fm.Count() != 0 {
		t.Errorf("expected count 0, got %d", fm.Count())
	}
}

// TestCount tracks the cache size across generations.
func TestCount(t *testing.T) {
	fm := NewFingerprintManager()
	if fm.Count() != 0 {
		t.Errorf("expected count 0, got %d", fm.Count())
	}

	fm.GetFingerprint("token1")
	fm.GetFingerprint("token2")
	fm.GetFingerprint("token3")

	if fm.Count() != 3 {
		t.Errorf("expected count 3, got %d", fm.Count())
	}
}

// TestApplyToRequest: every fingerprint field lands in the matching header.
func TestApplyToRequest(t *testing.T) {
	fm := NewFingerprintManager()
	fp := fm.GetFingerprint("token1")

	// error ignored: constant, known-valid URL
	req, _ := http.NewRequest("GET", "http://example.com", nil)
	fp.ApplyToRequest(req)

	if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
		t.Error("X-Kiro-SDK-Version header mismatch")
	}
	if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
		t.Error("X-Kiro-OS-Type header mismatch")
	}
	if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
		t.Error("X-Kiro-OS-Version header mismatch")
	}
	if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
		t.Error("X-Kiro-Node-Version header mismatch")
	}
	if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
		t.Error("X-Kiro-Version header mismatch")
	}
	if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
		t.Error("X-Kiro-Hash header mismatch")
	}
	if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
		t.Error("Accept-Language header mismatch")
	}
	if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
		t.Error("X-Screen-Resolution header mismatch")
	}
}

// TestGetFingerprint_OSVersionMatchesOSType: across many generations, the
// OS version always belongs to the pool for the chosen OS type.
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
	fm := NewFingerprintManager()

	for i := 0; i < 20; i++ {
		fp := fm.GetFingerprint("token" + string(rune('a'+i)))
		validVersions := osVersions[fp.OSType]
		found := false
		for _, v := range validVersions {
			if v == fp.OSVersion {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
		}
	}
}
|
||||||
|
|
||||||
|
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
const numGoroutines = 100
|
||||||
|
const numOperations = 100
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < numOperations; j++ {
|
||||||
|
tokenKey := "token" + string(rune('a'+id%26))
|
||||||
|
switch j % 4 {
|
||||||
|
case 0:
|
||||||
|
fm.GetFingerprint(tokenKey)
|
||||||
|
case 1:
|
||||||
|
fm.Count()
|
||||||
|
case 2:
|
||||||
|
fp := fm.GetFingerprint(tokenKey)
|
||||||
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
fp.ApplyToRequest(req)
|
||||||
|
case 3:
|
||||||
|
fm.RemoveFingerprint(tokenKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestKiroHashUniqueness verifies that fingerprints generated for 100 distinct
// token keys never share a KiroHash value.
func TestKiroHashUniqueness(t *testing.T) {
	fm := NewFingerprintManager()
	// Set of hashes seen so far; a repeat means a collision.
	hashes := make(map[string]bool)

	// NOTE(review): string(rune(i)) for i in [0,100) embeds control characters
	// in the token keys. The keys are still pairwise distinct, so the test is
	// valid, but strconv.Itoa(i) would be clearer — confirm intent.
	for i := 0; i < 100; i++ {
		fp := fm.GetFingerprint("token" + string(rune(i)))
		if hashes[fp.KiroHash] {
			t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
		}
		hashes[fp.KiroHash] = true
	}
}
|
||||||
|
|
||||||
|
func TestKiroHashFormat(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
fp := fm.GetFingerprint("token1")
|
||||||
|
|
||||||
|
if len(fp.KiroHash) != 64 {
|
||||||
|
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range fp.KiroHash {
|
||||||
|
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||||
|
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
174
internal/auth/kiro/jitter.go
Normal file
174
internal/auth/kiro/jitter.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Jitter configuration constants controlling randomized request pacing.
const (
	// JitterPercent is the default percentage of jitter to apply (±30%)
	JitterPercent = 0.30

	// Human-like delay ranges used by HumanLikeDelay.
	ShortDelayMin  = 50 * time.Millisecond  // Minimum for rapid consecutive operations
	ShortDelayMax  = 200 * time.Millisecond // Maximum for rapid consecutive operations
	NormalDelayMin = 1 * time.Second        // Minimum for normal thinking time
	NormalDelayMax = 3 * time.Second        // Maximum for normal thinking time
	LongDelayMin   = 5 * time.Second        // Minimum for reading/resting
	LongDelayMax   = 10 * time.Second       // Maximum for reading/resting

	// Probability thresholds for human-like behavior.
	ShortDelayProbability  = 0.20 // 20% chance of short delay (consecutive ops)
	LongDelayProbability   = 0.05 // 5% chance of long delay (reading/resting)
	NormalDelayProbability = 0.75 // 75% chance of normal delay; documentation only —
	// HumanLikeDelay's switch covers this case via its default branch.
)
|
||||||
|
|
||||||
|
// Package-level jitter state. jitterRand is lazily seeded exactly once via
// jitterRandOnce; jitterMu serializes access to jitterRand and lastRequestTime.
var (
	jitterRand      *rand.Rand // shared RNG, guarded by jitterMu after init
	jitterRandOnce  sync.Once  // ensures jitterRand is seeded exactly once
	jitterMu        sync.Mutex // guards jitterRand and lastRequestTime
	lastRequestTime time.Time  // timestamp of the previous HumanLikeDelay call
)
|
||||||
|
|
||||||
|
// initJitterRand lazily initializes the package-level random number generator
// used for jitter calculations. Safe for concurrent callers via sync.Once.
// The generator is seeded from the current wall clock, so sequences are
// unpredictable across process restarts (they are not reproducible).
func initJitterRand() {
	jitterRandOnce.Do(func() {
		jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
	})
}
|
||||||
|
|
||||||
|
// RandomDelay generates a random delay between min and max duration.
|
||||||
|
// Thread-safe implementation using mutex protection.
|
||||||
|
func RandomDelay(min, max time.Duration) time.Duration {
|
||||||
|
initJitterRand()
|
||||||
|
jitterMu.Lock()
|
||||||
|
defer jitterMu.Unlock()
|
||||||
|
|
||||||
|
if min >= max {
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
|
||||||
|
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||||
|
randomMs := jitterRand.Int63n(rangeMs)
|
||||||
|
return min + time.Duration(randomMs)*time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
// JitterDelay adds jitter to a base delay.
|
||||||
|
// Applies ±jitterPercent variation to the base delay.
|
||||||
|
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
|
||||||
|
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
|
||||||
|
initJitterRand()
|
||||||
|
jitterMu.Lock()
|
||||||
|
defer jitterMu.Unlock()
|
||||||
|
|
||||||
|
if jitterPercent <= 0 || jitterPercent > 1 {
|
||||||
|
jitterPercent = JitterPercent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate jitter range: base * jitterPercent
|
||||||
|
jitterRange := float64(baseDelay) * jitterPercent
|
||||||
|
|
||||||
|
// Generate random value in range [-jitterRange, +jitterRange]
|
||||||
|
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
|
||||||
|
|
||||||
|
result := time.Duration(float64(baseDelay) + jitter)
|
||||||
|
if result < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// JitterDelayDefault applies the default ±30% jitter (JitterPercent) to a
// base delay. Equivalent to JitterDelay(baseDelay, JitterPercent).
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
	return JitterDelay(baseDelay, JitterPercent)
}
|
||||||
|
|
||||||
|
// HumanLikeDelay generates a delay that mimics human behavior patterns.
//
// If the previous call was less than 500ms ago, a short delay (50-200ms) is
// always returned, simulating rapid consecutive operations. Otherwise the
// delay range is picked by probability:
//   - 20% chance: short delay (50-200ms) - consecutive rapid operations
//   - 5% chance: long delay (5-10s) - breaks/reading longer content
//   - 75% chance: normal delay (1-3s) - thinking/reading time
//
// Returns the delay duration (caller should call time.Sleep with this value).
// Thread-safe: the RNG and lastRequestTime are guarded by jitterMu.
func HumanLikeDelay() time.Duration {
	initJitterRand()
	jitterMu.Lock()
	defer jitterMu.Unlock()

	// Track time since last request for adaptive behavior. Note the tracker
	// is updated unconditionally, even on the first call.
	now := time.Now()
	timeSinceLastRequest := now.Sub(lastRequestTime)
	lastRequestTime = now

	// If requests are very close together, always use a short delay.
	// (timeSinceLastRequest > 0 excludes the zero-value first call.)
	if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
		rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
		randomMs := jitterRand.Int63n(rangeMs)
		return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
	}

	// Otherwise, use probability-based selection.
	roll := jitterRand.Float64()

	var min, max time.Duration
	switch {
	case roll < ShortDelayProbability:
		// Short delay - consecutive operations
		min, max = ShortDelayMin, ShortDelayMax
	case roll < ShortDelayProbability+LongDelayProbability:
		// Long delay - reading/resting
		min, max = LongDelayMin, LongDelayMax
	default:
		// Normal delay - thinking time (the remaining ~75%)
		min, max = NormalDelayMin, NormalDelayMax
	}

	// All ranges above span >= 150ms, so rangeMs is always positive here.
	rangeMs := max.Milliseconds() - min.Milliseconds()
	randomMs := jitterRand.Int63n(rangeMs)
	return min + time.Duration(randomMs)*time.Millisecond
}
|
||||||
|
|
||||||
|
// ApplyHumanLikeDelay applies human-like delay by sleeping.
|
||||||
|
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
|
||||||
|
func ApplyHumanLikeDelay() {
|
||||||
|
delay := HumanLikeDelay()
|
||||||
|
if delay > 0 {
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
|
||||||
|
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
|
||||||
|
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
|
||||||
|
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate exponential backoff: baseDelay * 2^attempt
|
||||||
|
backoff := baseDelay * time.Duration(1<<uint(attempt))
|
||||||
|
if backoff > maxDelay {
|
||||||
|
backoff = maxDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add ±30% jitter
|
||||||
|
return JitterDelay(backoff, JitterPercent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldSkipDelay reports whether the artificial pacing delay should be
// skipped for the current request. Today the only skip condition is a
// streaming response (where an added delay would stall the stream); extend
// this function to check additional conditions (e.g. WebSocket connections).
func ShouldSkipDelay(isStreaming bool) bool {
	return isStreaming
}
|
||||||
|
|
||||||
|
// ResetLastRequestTime resets the last request time tracker.
|
||||||
|
// Useful for testing or when starting a new session.
|
||||||
|
func ResetLastRequestTime() {
|
||||||
|
jitterMu.Lock()
|
||||||
|
defer jitterMu.Unlock()
|
||||||
|
lastRequestTime = time.Time{}
|
||||||
|
}
|
||||||
187
internal/auth/kiro/metrics.go
Normal file
187
internal/auth/kiro/metrics.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenMetrics holds performance metrics for a single token. The exported
// fields are aggregates maintained by TokenScorer.RecordRequest and
// SetQuotaRemaining; the unexported fields are the running totals from
// which SuccessRate and AvgLatency are recomputed.
type TokenMetrics struct {
	SuccessRate    float64   // Success rate (0.0 - 1.0)
	AvgLatency     float64   // Average latency in milliseconds
	QuotaRemaining float64   // Remaining quota (0.0 - 1.0)
	LastUsed       time.Time // Last usage timestamp
	FailCount      int       // Consecutive failure count (reset to 0 on success)
	TotalRequests  int       // Total request count
	successCount   int       // Internal: successful request count
	totalLatency   float64   // Internal: cumulative latency in milliseconds
}
|
||||||
|
|
||||||
|
// TokenScorer manages per-token metrics and computes weighted health scores
// used to pick the best token for the next request. Safe for concurrent use.
type TokenScorer struct {
	mu      sync.RWMutex             // guards metrics
	metrics map[string]*TokenMetrics // per-token state keyed by token key

	// Scoring weights; see CalculateScore for how each component is combined.
	successRateWeight     float64 // weight of the success-rate component
	quotaWeight           float64 // weight of the remaining-quota component
	latencyWeight         float64 // weight of the latency component
	lastUsedWeight        float64 // weight of the recency (idle-time) component
	failPenaltyMultiplier float64 // score penalty per consecutive failure
}
|
||||||
|
|
||||||
|
// NewTokenScorer creates a new TokenScorer with default weights.
|
||||||
|
func NewTokenScorer() *TokenScorer {
|
||||||
|
return &TokenScorer{
|
||||||
|
metrics: make(map[string]*TokenMetrics),
|
||||||
|
successRateWeight: 0.4,
|
||||||
|
quotaWeight: 0.25,
|
||||||
|
latencyWeight: 0.2,
|
||||||
|
lastUsedWeight: 0.15,
|
||||||
|
failPenaltyMultiplier: 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrCreateMetrics returns existing metrics or creates new ones.
|
||||||
|
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
|
||||||
|
if m, ok := s.metrics[tokenKey]; ok {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
m := &TokenMetrics{
|
||||||
|
SuccessRate: 1.0,
|
||||||
|
QuotaRemaining: 1.0,
|
||||||
|
}
|
||||||
|
s.metrics[tokenKey] = m
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordRequest records the result of a request for a token.
|
||||||
|
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
m := s.getOrCreateMetrics(tokenKey)
|
||||||
|
m.TotalRequests++
|
||||||
|
m.LastUsed = time.Now()
|
||||||
|
m.totalLatency += float64(latency.Milliseconds())
|
||||||
|
|
||||||
|
if success {
|
||||||
|
m.successCount++
|
||||||
|
m.FailCount = 0
|
||||||
|
} else {
|
||||||
|
m.FailCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update derived metrics
|
||||||
|
if m.TotalRequests > 0 {
|
||||||
|
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
|
||||||
|
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuotaRemaining updates the remaining quota for a token.
|
||||||
|
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
m := s.getOrCreateMetrics(tokenKey)
|
||||||
|
m.QuotaRemaining = quota
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns a copy of the metrics for a token.
|
||||||
|
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
if m, ok := s.metrics[tokenKey]; ok {
|
||||||
|
copy := *m
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateScore computes the score for a token (higher is better).
|
||||||
|
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
m, ok := s.metrics[tokenKey]
|
||||||
|
if !ok {
|
||||||
|
return 1.0 // New tokens get a high initial score
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success rate component (0-1)
|
||||||
|
successScore := m.SuccessRate
|
||||||
|
|
||||||
|
// Quota component (0-1)
|
||||||
|
quotaScore := m.QuotaRemaining
|
||||||
|
|
||||||
|
// Latency component (normalized, lower is better)
|
||||||
|
// Using exponential decay: score = e^(-latency/1000)
|
||||||
|
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
|
||||||
|
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
|
||||||
|
if m.TotalRequests == 0 {
|
||||||
|
latencyScore = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last used component (prefer tokens not recently used)
|
||||||
|
// Score increases as time since last use increases
|
||||||
|
timeSinceUse := time.Since(m.LastUsed).Seconds()
|
||||||
|
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
|
||||||
|
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
|
||||||
|
if m.LastUsed.IsZero() {
|
||||||
|
lastUsedScore = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate weighted score
|
||||||
|
score := s.successRateWeight*successScore +
|
||||||
|
s.quotaWeight*quotaScore +
|
||||||
|
s.latencyWeight*latencyScore +
|
||||||
|
s.lastUsedWeight*lastUsedScore
|
||||||
|
|
||||||
|
// Apply consecutive failure penalty
|
||||||
|
if m.FailCount > 0 {
|
||||||
|
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
|
||||||
|
score = score * math.Max(0, 1.0-penalty)
|
||||||
|
}
|
||||||
|
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectBestToken selects the token with the highest score.
|
||||||
|
func (s *TokenScorer) SelectBestToken(tokens []string) string {
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(tokens) == 1 {
|
||||||
|
return tokens[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
bestToken := tokens[0]
|
||||||
|
bestScore := s.CalculateScore(tokens[0])
|
||||||
|
|
||||||
|
for _, token := range tokens[1:] {
|
||||||
|
score := s.CalculateScore(token)
|
||||||
|
if score > bestScore {
|
||||||
|
bestScore = score
|
||||||
|
bestToken = token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetMetrics clears all metrics for a token.
|
||||||
|
func (s *TokenScorer) ResetMetrics(tokenKey string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.metrics, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetAllMetrics clears all stored metrics.
|
||||||
|
func (s *TokenScorer) ResetAllMetrics() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.metrics = make(map[string]*TokenMetrics)
|
||||||
|
}
|
||||||
301
internal/auth/kiro/metrics_test.go
Normal file
301
internal/auth/kiro/metrics_test.go
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTokenScorer(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("expected non-nil TokenScorer")
|
||||||
|
}
|
||||||
|
if s.metrics == nil {
|
||||||
|
t.Error("expected non-nil metrics map")
|
||||||
|
}
|
||||||
|
if s.successRateWeight != 0.4 {
|
||||||
|
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
|
||||||
|
}
|
||||||
|
if s.quotaWeight != 0.25 {
|
||||||
|
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordRequest_Success(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("expected non-nil metrics")
|
||||||
|
}
|
||||||
|
if m.TotalRequests != 1 {
|
||||||
|
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
|
||||||
|
}
|
||||||
|
if m.SuccessRate != 1.0 {
|
||||||
|
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
|
||||||
|
}
|
||||||
|
if m.FailCount != 0 {
|
||||||
|
t.Errorf("expected FailCount 0, got %d", m.FailCount)
|
||||||
|
}
|
||||||
|
if m.AvgLatency != 100 {
|
||||||
|
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordRequest_Failure(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", false, 200*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.SuccessRate != 0.0 {
|
||||||
|
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
|
||||||
|
}
|
||||||
|
if m.FailCount != 1 {
|
||||||
|
t.Errorf("expected FailCount 1, got %d", m.FailCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordRequest_MixedResults(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.TotalRequests != 4 {
|
||||||
|
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
|
||||||
|
}
|
||||||
|
if m.SuccessRate != 0.75 {
|
||||||
|
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
|
||||||
|
}
|
||||||
|
if m.FailCount != 0 {
|
||||||
|
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.FailCount != 3 {
|
||||||
|
t.Errorf("expected FailCount 3, got %d", m.FailCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetQuotaRemaining(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.SetQuotaRemaining("token1", 0.5)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.QuotaRemaining != 0.5 {
|
||||||
|
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMetrics_NonExistent(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
m := s.GetMetrics("nonexistent")
|
||||||
|
if m != nil {
|
||||||
|
t.Error("expected nil metrics for non-existent token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMetrics_ReturnsCopy(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
|
||||||
|
m1 := s.GetMetrics("token1")
|
||||||
|
m1.TotalRequests = 999
|
||||||
|
|
||||||
|
m2 := s.GetMetrics("token1")
|
||||||
|
if m2.TotalRequests == 999 {
|
||||||
|
t.Error("GetMetrics should return a copy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateScore_NewToken(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
score := s.CalculateScore("newtoken")
|
||||||
|
if score != 1.0 {
|
||||||
|
t.Errorf("expected score 1.0 for new token, got %f", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateScore_PerfectToken(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 50*time.Millisecond)
|
||||||
|
s.SetQuotaRemaining("token1", 1.0)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
score := s.CalculateScore("token1")
|
||||||
|
if score < 0.5 || score > 1.0 {
|
||||||
|
t.Errorf("expected high score for perfect token, got %f", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateScore_FailedToken(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
s.RecordRequest("token1", false, 1000*time.Millisecond)
|
||||||
|
}
|
||||||
|
s.SetQuotaRemaining("token1", 0.1)
|
||||||
|
|
||||||
|
score := s.CalculateScore("token1")
|
||||||
|
if score > 0.5 {
|
||||||
|
t.Errorf("expected low score for failed token, got %f", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateScore_FailPenalty(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
scoreNoFail := s.CalculateScore("token1")
|
||||||
|
|
||||||
|
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||||
|
scoreWithFail := s.CalculateScore("token1")
|
||||||
|
|
||||||
|
if scoreWithFail >= scoreNoFail {
|
||||||
|
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectBestToken_Empty(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
best := s.SelectBestToken([]string{})
|
||||||
|
if best != "" {
|
||||||
|
t.Errorf("expected empty string for empty tokens, got %s", best)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectBestToken_SingleToken(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
best := s.SelectBestToken([]string{"token1"})
|
||||||
|
if best != "token1" {
|
||||||
|
t.Errorf("expected token1, got %s", best)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectBestToken_MultipleTokens(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
|
||||||
|
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||||
|
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||||
|
s.SetQuotaRemaining("bad", 0.1)
|
||||||
|
|
||||||
|
s.RecordRequest("good", true, 50*time.Millisecond)
|
||||||
|
s.SetQuotaRemaining("good", 0.9)
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
best := s.SelectBestToken([]string{"bad", "good"})
|
||||||
|
if best != "good" {
|
||||||
|
t.Errorf("expected good token to be selected, got %s", best)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetMetrics(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
s.ResetMetrics("token1")
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m != nil {
|
||||||
|
t.Error("expected nil metrics after reset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetAllMetrics(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token2", true, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token3", true, 100*time.Millisecond)
|
||||||
|
|
||||||
|
s.ResetAllMetrics()
|
||||||
|
|
||||||
|
if s.GetMetrics("token1") != nil {
|
||||||
|
t.Error("expected nil metrics for token1 after reset all")
|
||||||
|
}
|
||||||
|
if s.GetMetrics("token2") != nil {
|
||||||
|
t.Error("expected nil metrics for token2 after reset all")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
const numGoroutines = 50
|
||||||
|
const numOperations = 100
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
tokenKey := "token" + string(rune('a'+id%10))
|
||||||
|
for j := 0; j < numOperations; j++ {
|
||||||
|
switch j % 6 {
|
||||||
|
case 0:
|
||||||
|
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
|
||||||
|
case 1:
|
||||||
|
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
|
||||||
|
case 2:
|
||||||
|
s.GetMetrics(tokenKey)
|
||||||
|
case 3:
|
||||||
|
s.CalculateScore(tokenKey)
|
||||||
|
case 4:
|
||||||
|
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
|
||||||
|
case 5:
|
||||||
|
if j%20 == 0 {
|
||||||
|
s.ResetMetrics(tokenKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAvgLatencyCalculation(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", true, 200*time.Millisecond)
|
||||||
|
s.RecordRequest("token1", true, 300*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.AvgLatency != 200 {
|
||||||
|
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLastUsedUpdated(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
before := time.Now()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.LastUsed.Before(before) {
|
||||||
|
t.Error("expected LastUsed to be after test start time")
|
||||||
|
}
|
||||||
|
if m.LastUsed.After(time.Now()) {
|
||||||
|
t.Error("expected LastUsed to be before or equal to now")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultQuotaForNewToken(t *testing.T) {
|
||||||
|
s := NewTokenScorer()
|
||||||
|
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||||
|
|
||||||
|
m := s.GetMetrics("token1")
|
||||||
|
if m.QuotaRemaining != 1.0 {
|
||||||
|
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -227,6 +227,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
|||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
AuthMethod: "social",
|
AuthMethod: "social",
|
||||||
Provider: "", // Caller should preserve original provider
|
Provider: "", // Caller should preserve original provider
|
||||||
|
Region: "us-east-1",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,6 +286,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
|||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
AuthMethod: "social",
|
AuthMethod: "social",
|
||||||
Provider: "", // Caller should preserve original provider
|
Provider: "", // Caller should preserve original provider
|
||||||
|
Region: "us-east-1",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
982
internal/auth/kiro/oauth_web.go
Normal file
982
internal/auth/kiro/oauth_web.go
Normal file
@@ -0,0 +1,982 @@
|
|||||||
|
// Package kiro provides OAuth Web authentication for Kiro.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Web OAuth flow tuning.
const (
	// defaultSessionExpiry bounds how long a pending browser auth session
	// remains valid before it is abandoned.
	defaultSessionExpiry = 10 * time.Minute
	// pollIntervalSeconds is the default polling cadence, in seconds,
	// for device-code flows.
	pollIntervalSeconds = 5
)
|
||||||
|
|
||||||
|
// authSessionStatus enumerates the lifecycle states of a web auth session
// (see statusPending, statusSuccess, statusFailed).
type authSessionStatus string
|
||||||
|
|
||||||
|
// Web auth session lifecycle states.
const (
	statusPending authSessionStatus = "pending" // flow started, token not yet obtained
	statusSuccess authSessionStatus = "success" // token obtained
	statusFailed  authSessionStatus = "failed"  // flow errored out
)
|
||||||
|
|
||||||
|
// webAuthSession tracks one in-flight browser authentication attempt, from
// the initial redirect through polling to the final token (or error).
type webAuthSession struct {
	stateID         string            // opaque ID correlating callback/status requests
	deviceCode      string            // device-authorization code — presumably for builder-id/idc flows; confirm
	userCode        string            // user-visible pairing code for device flows
	authURL         string            // URL the user is sent to
	verificationURI string            // device-flow verification page URI
	expiresIn       int               // provider-reported TTL, seconds
	interval        int               // provider-suggested polling interval, seconds
	status          authSessionStatus // pending/success/failed
	startedAt       time.Time         // when the flow began
	completedAt     time.Time         // when the flow finished (success or failure)
	expiresAt       time.Time         // hard deadline for the session
	error           string            // failure description when status is failed
	tokenData       *KiroTokenData    // obtained credentials on success
	ssoClient       *SSOOIDCClient    // SSO-OIDC client associated with the session
	clientID        string
	clientSecret    string
	region          string
	cancelFunc      context.CancelFunc // cancels the session's background work
	authMethod      string             // "google", "github", "builder-id", "idc"
	startURL        string             // Used for IDC
	codeVerifier    string             // Used for social auth PKCE
	codeChallenge   string             // Used for social auth PKCE
}
|
||||||
|
|
||||||
|
// OAuthWebHandler serves the Kiro OAuth web UI and tracks in-flight
// authentication sessions.
type OAuthWebHandler struct {
	cfg             *config.Config             // proxy configuration (used for auth dir and HTTP clients)
	sessions        map[string]*webAuthSession // state ID -> session; guarded by mu
	mu              sync.RWMutex               // guards sessions and each session's mutable fields
	onTokenObtained func(*KiroTokenData)       // optional hook invoked after any token is obtained
}
|
||||||
|
|
||||||
|
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||||
|
return &OAuthWebHandler{
|
||||||
|
cfg: cfg,
|
||||||
|
sessions: make(map[string]*webAuthSession),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTokenCallback registers a hook invoked whenever a token is obtained
// (device flow, social callback, import, or manual refresh).
//
// NOTE(review): onTokenObtained is read without locking by background
// goroutines — set the callback before starting any auth flow.
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
	h.onTokenObtained = callback
}
|
||||||
|
|
||||||
|
// RegisterRoutes mounts the Kiro OAuth web endpoints under /v0/oauth/kiro
// on the given router.
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
	oauth := router.Group("/v0/oauth/kiro")
	{
		oauth.GET("", h.handleSelect)                         // method-selection page
		oauth.GET("/start", h.handleStart)                    // begin a flow (?method=...)
		oauth.GET("/callback", h.handleCallback)              // device-code flow result page
		oauth.GET("/social/callback", h.handleSocialCallback) // social PKCE redirect target
		oauth.GET("/status", h.handleStatus)                  // JSON polling endpoint for the browser
		oauth.POST("/import", h.handleImportToken)            // manual refresh-token import
		oauth.POST("/refresh", h.handleManualRefresh)         // manual refresh of stored tokens
	}
}
|
||||||
|
|
||||||
|
func generateStateID() (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSelect serves the authentication-method selection page.
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
	h.renderSelectPage(c)
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||||
|
method := c.Query("method")
|
||||||
|
|
||||||
|
if method == "" {
|
||||||
|
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch method {
|
||||||
|
case "google", "github":
|
||||||
|
// Google/GitHub social login is not supported for third-party apps
|
||||||
|
// due to AWS Cognito redirect_uri restrictions
|
||||||
|
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
|
||||||
|
case "builder-id":
|
||||||
|
h.startBuilderIDAuth(c)
|
||||||
|
case "idc":
|
||||||
|
h.startIDCAuth(c)
|
||||||
|
default:
|
||||||
|
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startSocialAuth kicks off a browser-redirect PKCE flow for Google or
// GitHub login and records a pending session keyed by a random state ID.
// The browser returns to /social/callback, which finishes the exchange.
//
// NOTE(review): handleStart currently rejects "google"/"github" before
// reaching this function, so this path looks unreachable from the
// registered routes — confirm before relying on or removing it.
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
	stateID, err := generateStateID()
	if err != nil {
		h.renderError(c, "Failed to generate state parameter")
		return
	}

	// PKCE verifier/challenge pair for the later authorization-code exchange.
	codeVerifier, codeChallenge, err := generatePKCE()
	if err != nil {
		h.renderError(c, "Failed to generate PKCE parameters")
		return
	}

	socialClient := NewSocialAuthClient(h.cfg)

	var provider string
	if method == "google" {
		provider = string(ProviderGoogle)
	} else {
		provider = string(ProviderGitHub)
	}

	redirectURI := h.getSocialCallbackURL(c)
	authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)

	// This context bounds the whole login attempt; its expiry marks the
	// session as timed out via the goroutine below.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)

	session := &webAuthSession{
		stateID:       stateID,
		authMethod:    method,
		authURL:       authURL,
		status:        statusPending,
		startedAt:     time.Now(),
		expiresIn:     600,
		codeVerifier:  codeVerifier,
		codeChallenge: codeChallenge,
		region:        "us-east-1",
		cancelFunc:    cancel,
	}

	h.mu.Lock()
	h.sessions[stateID] = session
	h.mu.Unlock()

	// Fail the session if the user never completes the redirect flow.
	go func() {
		<-ctx.Done()
		h.mu.Lock()
		if session.status == statusPending {
			session.status = statusFailed
			session.error = "Authentication timed out"
		}
		h.mu.Unlock()
	}()

	c.Redirect(http.StatusFound, authURL)
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
|
||||||
|
scheme := "http"
|
||||||
|
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// startBuilderIDAuth begins an AWS Builder ID device-code flow: it
// registers an OIDC client, starts device authorization, records a
// pending session, launches background token polling, and renders the
// page showing the user code and verification link.
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
	stateID, err := generateStateID()
	if err != nil {
		h.renderError(c, "Failed to generate state parameter")
		return
	}

	// Builder ID always uses the fixed start URL and default region.
	region := defaultIDCRegion
	startURL := builderIDStartURL

	ssoClient := NewSSOOIDCClient(h.cfg)

	regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
	if err != nil {
		log.Errorf("OAuth Web: failed to register client: %v", err)
		h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
		return
	}

	authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
		c.Request.Context(),
		regResp.ClientID,
		regResp.ClientSecret,
		startURL,
		region,
	)
	if err != nil {
		log.Errorf("OAuth Web: failed to start device authorization: %v", err)
		h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
		return
	}

	// The polling context lives exactly as long as the provider says the
	// device code is valid.
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)

	session := &webAuthSession{
		stateID:         stateID,
		deviceCode:      authResp.DeviceCode,
		userCode:        authResp.UserCode,
		authURL:         authResp.VerificationURIComplete,
		verificationURI: authResp.VerificationURI,
		expiresIn:       authResp.ExpiresIn,
		interval:        authResp.Interval,
		status:          statusPending,
		startedAt:       time.Now(),
		ssoClient:       ssoClient,
		clientID:        regResp.ClientID,
		clientSecret:    regResp.ClientSecret,
		region:          region,
		authMethod:      "builder-id",
		startURL:        startURL,
		cancelFunc:      cancel,
	}

	h.mu.Lock()
	h.sessions[stateID] = session
	h.mu.Unlock()

	// Poll the token endpoint in the background while the user completes
	// the flow in their browser.
	go h.pollForToken(ctx, session)

	h.renderStartPage(c, session)
}
|
||||||
|
|
||||||
|
// startIDCAuth begins an AWS IAM Identity Center (IDC) device-code flow
// against the caller-supplied ?startUrl= (required) and ?region=
// (defaults to defaultIDCRegion). Apart from those inputs it mirrors
// startBuilderIDAuth: register client, start device authorization,
// record session, poll in the background, render the user-code page.
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
	startURL := c.Query("startUrl")
	region := c.Query("region")

	if startURL == "" {
		h.renderError(c, "Missing startUrl parameter for IDC authentication")
		return
	}
	if region == "" {
		region = defaultIDCRegion
	}

	stateID, err := generateStateID()
	if err != nil {
		h.renderError(c, "Failed to generate state parameter")
		return
	}

	ssoClient := NewSSOOIDCClient(h.cfg)

	regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
	if err != nil {
		log.Errorf("OAuth Web: failed to register client: %v", err)
		h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
		return
	}

	authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
		c.Request.Context(),
		regResp.ClientID,
		regResp.ClientSecret,
		startURL,
		region,
	)
	if err != nil {
		log.Errorf("OAuth Web: failed to start device authorization: %v", err)
		h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
		return
	}

	// The polling context lives exactly as long as the device code.
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)

	session := &webAuthSession{
		stateID:         stateID,
		deviceCode:      authResp.DeviceCode,
		userCode:        authResp.UserCode,
		authURL:         authResp.VerificationURIComplete,
		verificationURI: authResp.VerificationURI,
		expiresIn:       authResp.ExpiresIn,
		interval:        authResp.Interval,
		status:          statusPending,
		startedAt:       time.Now(),
		ssoClient:       ssoClient,
		clientID:        regResp.ClientID,
		clientSecret:    regResp.ClientSecret,
		region:          region,
		authMethod:      "idc",
		startURL:        startURL,
		cancelFunc:      cancel,
	}

	h.mu.Lock()
	h.sessions[stateID] = session
	h.mu.Unlock()

	go h.pollForToken(ctx, session)

	h.renderStartPage(c, session)
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||||
|
defer session.cancelFunc()
|
||||||
|
|
||||||
|
interval := time.Duration(session.interval) * time.Second
|
||||||
|
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||||
|
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
h.mu.Lock()
|
||||||
|
if session.status == statusPending {
|
||||||
|
session.status = statusFailed
|
||||||
|
session.error = "Authentication timed out"
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||||
|
ctx,
|
||||||
|
session.clientID,
|
||||||
|
session.clientSecret,
|
||||||
|
session.deviceCode,
|
||||||
|
session.region,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if errStr == ErrAuthorizationPending.Error() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errStr == ErrSlowDown.Error() {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
ticker.Reset(interval)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
session.status = statusFailed
|
||||||
|
session.error = errStr
|
||||||
|
session.completedAt = time.Now()
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||||
|
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||||
|
|
||||||
|
tokenData := &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: session.authMethod,
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: session.clientID,
|
||||||
|
ClientSecret: session.clientSecret,
|
||||||
|
Email: email,
|
||||||
|
Region: session.region,
|
||||||
|
StartURL: session.startURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
session.status = statusSuccess
|
||||||
|
session.completedAt = time.Now()
|
||||||
|
session.expiresAt = expiresAt
|
||||||
|
session.tokenData = tokenData
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
if h.onTokenObtained != nil {
|
||||||
|
h.onTokenObtained(tokenData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save token to file
|
||||||
|
h.saveTokenToFile(tokenData)
|
||||||
|
|
||||||
|
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveTokenToFile persists tokenData as a kiro-*.json file in the auth
// directory (config AuthDir if resolvable, otherwise ~/.cli-proxy-api).
// Failures are logged but not returned: persistence is best-effort and
// callers have already delivered the token via the callback.
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
	// Get auth directory from config or use default.
	authDir := ""
	if h.cfg != nil && h.cfg.AuthDir != "" {
		var err error
		authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
		if err != nil {
			// Resolution failure leaves authDir empty and falls through
			// to the home-directory default below.
			log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
		}
	}

	// Fall back to default location.
	if authDir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			log.Errorf("OAuth Web: failed to get home directory: %v", err)
			return
		}
		authDir = filepath.Join(home, ".cli-proxy-api")
	}

	// Create directory if not exists (0700: tokens are secrets).
	if err := os.MkdirAll(authDir, 0700); err != nil {
		log.Errorf("OAuth Web: failed to create auth directory: %v", err)
		return
	}

	// Generate filename based on auth method.
	// Format: kiro-{authMethod}.json or kiro-{authMethod}-{email}.json
	fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
	if tokenData.Email != "" {
		// Sanitize email for filename (replace @ and . with -).
		sanitizedEmail := tokenData.Email
		sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
		sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
		fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
	}

	authFilePath := filepath.Join(authDir, fileName)

	// Convert to storage format and save.
	storage := &KiroTokenStorage{
		Type:         "kiro",
		AccessToken:  tokenData.AccessToken,
		RefreshToken: tokenData.RefreshToken,
		ProfileArn:   tokenData.ProfileArn,
		ExpiresAt:    tokenData.ExpiresAt,
		AuthMethod:   tokenData.AuthMethod,
		Provider:     tokenData.Provider,
		LastRefresh:  time.Now().Format(time.RFC3339),
		ClientID:     tokenData.ClientID,
		ClientSecret: tokenData.ClientSecret,
		Region:       tokenData.Region,
		StartURL:     tokenData.StartURL,
		Email:        tokenData.Email,
	}

	if err := storage.SaveTokenToFile(authFilePath); err != nil {
		log.Errorf("OAuth Web: failed to save token to file: %v", err)
		return
	}

	log.Infof("OAuth Web: token saved to %s", authFilePath)
}
|
||||||
|
|
||||||
|
// ssoClient returns the SSO OIDC client captured on the session when its
// device-code flow was started. The receiver is unused; the method exists
// only for call-site readability in pollForToken.
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
	return session.ssoClient
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||||
|
stateID := c.Query("state")
|
||||||
|
errParam := c.Query("error")
|
||||||
|
|
||||||
|
if errParam != "" {
|
||||||
|
h.renderError(c, errParam)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stateID == "" {
|
||||||
|
h.renderError(c, "Missing state parameter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.RLock()
|
||||||
|
session, exists := h.sessions[stateID]
|
||||||
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
h.renderError(c, "Invalid or expired session")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.status == statusSuccess {
|
||||||
|
h.renderSuccess(c, session)
|
||||||
|
} else if session.status == statusFailed {
|
||||||
|
h.renderError(c, session.error)
|
||||||
|
} else {
|
||||||
|
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSocialCallback completes a social (Google/GitHub) PKCE flow: it
// validates the state/code query parameters, exchanges the authorization
// code for tokens, marks the session successful, persists the token, and
// renders the success page. Provider errors relayed via ?error= are
// surfaced directly.
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
	stateID := c.Query("state")
	code := c.Query("code")
	errParam := c.Query("error")

	if errParam != "" {
		h.renderError(c, errParam)
		return
	}

	if stateID == "" {
		h.renderError(c, "Missing state parameter")
		return
	}

	if code == "" {
		h.renderError(c, "Missing authorization code")
		return
	}

	h.mu.RLock()
	session, exists := h.sessions[stateID]
	h.mu.RUnlock()

	if !exists {
		h.renderError(c, "Invalid or expired session")
		return
	}

	// Only sessions created by the social flow carry a PKCE verifier.
	if session.authMethod != "google" && session.authMethod != "github" {
		h.renderError(c, "Invalid session type for social callback")
		return
	}

	socialClient := NewSocialAuthClient(h.cfg)
	// Must match the redirect_uri used when the login URL was built.
	redirectURI := h.getSocialCallbackURL(c)

	tokenReq := &CreateTokenRequest{
		Code:         code,
		CodeVerifier: session.codeVerifier,
		RedirectURI:  redirectURI,
	}

	tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
	if err != nil {
		log.Errorf("OAuth Web: social token exchange failed: %v", err)
		h.mu.Lock()
		session.status = statusFailed
		session.error = fmt.Sprintf("Token exchange failed: %v", err)
		session.completedAt = time.Now()
		h.mu.Unlock()
		h.renderError(c, session.error)
		return
	}

	// Default to one hour when the provider omits or zeroes expires_in.
	expiresIn := tokenResp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)

	email := ExtractEmailFromJWT(tokenResp.AccessToken)

	var provider string
	if session.authMethod == "google" {
		provider = string(ProviderGoogle)
	} else {
		provider = string(ProviderGitHub)
	}

	tokenData := &KiroTokenData{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		ProfileArn:   tokenResp.ProfileArn,
		ExpiresAt:    expiresAt.Format(time.RFC3339),
		AuthMethod:   session.authMethod,
		Provider:     provider,
		Email:        email,
		Region:       "us-east-1",
	}

	h.mu.Lock()
	session.status = statusSuccess
	session.completedAt = time.Now()
	session.expiresAt = expiresAt
	session.tokenData = tokenData
	h.mu.Unlock()

	// Stop the timeout watchdog started when the session was created.
	if session.cancelFunc != nil {
		session.cancelFunc()
	}

	if h.onTokenObtained != nil {
		h.onTokenObtained(tokenData)
	}

	// Save token to file.
	h.saveTokenToFile(tokenData)

	log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
	h.renderSuccess(c, session)
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||||
|
stateID := c.Query("state")
|
||||||
|
if stateID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.RLock()
|
||||||
|
session, exists := h.sessions[stateID]
|
||||||
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := gin.H{
|
||||||
|
"status": string(session.status),
|
||||||
|
}
|
||||||
|
|
||||||
|
switch session.status {
|
||||||
|
case statusPending:
|
||||||
|
elapsed := time.Since(session.startedAt).Seconds()
|
||||||
|
remaining := float64(session.expiresIn) - elapsed
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
response["remaining_seconds"] = int(remaining)
|
||||||
|
case statusSuccess:
|
||||||
|
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||||
|
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||||
|
case statusFailed:
|
||||||
|
response["error"] = session.error
|
||||||
|
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderStartPage renders the device-code instructions page: the user
// code, the verification link, the expiry window, and the state ID the
// page uses to poll /status.
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
	tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
	if err != nil {
		log.Errorf("OAuth Web: failed to parse template: %v", err)
		c.String(http.StatusInternalServerError, "Template error")
		return
	}

	// Only fields set once at session creation are read here, so no lock
	// is taken.
	data := map[string]interface{}{
		"AuthURL":   session.authURL,
		"UserCode":  session.userCode,
		"ExpiresIn": session.expiresIn,
		"StateID":   session.stateID,
	}

	c.Header("Content-Type", "text/html; charset=utf-8")
	if err := tmpl.Execute(c.Writer, data); err != nil {
		// Headers may already be flushed; just log.
		log.Errorf("OAuth Web: failed to render template: %v", err)
	}
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
|
||||||
|
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("OAuth Web: failed to parse select template: %v", err)
|
||||||
|
c.String(http.StatusInternalServerError, "Template error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := tmpl.Execute(c.Writer, nil); err != nil {
|
||||||
|
log.Errorf("OAuth Web: failed to render select template: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||||
|
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||||
|
c.String(http.StatusInternalServerError, "Template error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"Error": errMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.Status(http.StatusBadRequest)
|
||||||
|
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||||
|
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderSuccess renders the success page showing the token expiry time.
//
// NOTE(review): session.expiresAt is written under h.mu by the flow that
// completed the session; callers should have observed statusSuccess under
// the lock before calling this, otherwise this read races — confirm.
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
	tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
	if err != nil {
		log.Errorf("OAuth Web: failed to parse success template: %v", err)
		c.String(http.StatusInternalServerError, "Template error")
		return
	}

	data := map[string]interface{}{
		"ExpiresAt": session.expiresAt.Format(time.RFC3339),
	}

	c.Header("Content-Type", "text/html; charset=utf-8")
	if err := tmpl.Execute(c.Writer, data); err != nil {
		log.Errorf("OAuth Web: failed to render success template: %v", err)
	}
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for id, session := range h.sessions {
|
||||||
|
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||||
|
delete(h.sessions, id)
|
||||||
|
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||||
|
session.cancelFunc()
|
||||||
|
delete(h.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
session, exists := h.sessions[stateID]
|
||||||
|
return session, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportTokenRequest represents the request body for token import.
type ImportTokenRequest struct {
	RefreshToken string `json:"refreshToken"` // refresh token copied from Kiro IDE ("aorAAAAAG..." format)
}
|
||||||
|
|
||||||
|
// handleImportToken handles manual refresh token import from Kiro IDE.
// It validates the token's prefix, exchanges it for a fresh access token
// (which doubles as validation), persists the result, and responds with
// the generated filename. Responds with JSON {success, error|message}.
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
	var req ImportTokenRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"error":   "Invalid request body",
		})
		return
	}

	refreshToken := strings.TrimSpace(req.RefreshToken)
	if refreshToken == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"error":   "Refresh token is required",
		})
		return
	}

	// Validate token format.
	if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"error":   "Invalid token format. Token should start with aorAAAAAG...",
		})
		return
	}

	// Create social auth client to refresh and validate the token.
	socialClient := NewSocialAuthClient(h.cfg)

	// Refresh the token to validate it and get access token.
	tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
	if err != nil {
		log.Errorf("OAuth Web: token refresh failed during import: %v", err)
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"error":   fmt.Sprintf("Token validation failed: %v", err),
		})
		return
	}

	// Set the original refresh token (the refreshed one might be empty).
	if tokenData.RefreshToken == "" {
		tokenData.RefreshToken = refreshToken
	}
	// Imported tokens are labeled distinctly from live social logins.
	tokenData.AuthMethod = "social"
	tokenData.Provider = "imported"

	// Notify callback if set.
	if h.onTokenObtained != nil {
		h.onTokenObtained(tokenData)
	}

	// Save token to file.
	h.saveTokenToFile(tokenData)

	// Generate filename for response (must mirror saveTokenToFile's naming).
	fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
	if tokenData.Email != "" {
		sanitizedEmail := strings.ReplaceAll(tokenData.Email, "@", "-")
		sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
		fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
	}

	log.Infof("OAuth Web: token imported successfully")
	c.JSON(http.StatusOK, gin.H{
		"success":  true,
		"message":  "Token imported successfully",
		"fileName": fileName,
	})
}
|
||||||
|
|
||||||
|
// handleManualRefresh handles manual token refresh requests from the web UI.
// This allows users to trigger a token refresh when needed, without waiting
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
//
// It scans the auth directory for kiro-*.json files, refreshes each one,
// writes the result back atomically (temp file + rename), and reports the
// count plus any per-file warnings as JSON.
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
	// Resolve the auth directory the same way saveTokenToFile does.
	authDir := ""
	if h.cfg != nil && h.cfg.AuthDir != "" {
		var err error
		authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
		if err != nil {
			log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
		}
	}

	if authDir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{
				"success": false,
				"error":   "Failed to get home directory",
			})
			return
		}
		authDir = filepath.Join(home, ".cli-proxy-api")
	}

	// Find all kiro token files in the auth directory.
	files, err := os.ReadDir(authDir)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"error":   fmt.Sprintf("Failed to read auth directory: %v", err),
		})
		return
	}

	var refreshedCount int
	var errors []string

	for _, file := range files {
		if file.IsDir() {
			continue
		}
		name := file.Name()
		if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
			continue
		}

		filePath := filepath.Join(authDir, name)
		data, err := os.ReadFile(filePath)
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
			continue
		}

		var storage KiroTokenStorage
		if err := json.Unmarshal(data, &storage); err != nil {
			errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
			continue
		}

		if storage.RefreshToken == "" {
			errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
			continue
		}

		// Refresh token using the same logic as kiro_executor.Refresh.
		tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
			continue
		}

		// Update storage with new token data; keep the old refresh token
		// and profile ARN when the provider omits them.
		storage.AccessToken = tokenData.AccessToken
		if tokenData.RefreshToken != "" {
			storage.RefreshToken = tokenData.RefreshToken
		}
		storage.ExpiresAt = tokenData.ExpiresAt
		storage.LastRefresh = time.Now().Format(time.RFC3339)
		if tokenData.ProfileArn != "" {
			storage.ProfileArn = tokenData.ProfileArn
		}

		// Write updated token back to file.
		updatedData, err := json.MarshalIndent(storage, "", "  ")
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
			continue
		}

		// Atomic replace: write a temp file, then rename over the original.
		tmpFile := filePath + ".tmp"
		if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
			errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
			continue
		}
		if err := os.Rename(tmpFile, filePath); err != nil {
			errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
			continue
		}

		log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
		refreshedCount++

		// Notify callback if set.
		if h.onTokenObtained != nil {
			h.onTokenObtained(tokenData)
		}
	}

	// Only a total failure is reported as an error; partial failures are
	// returned as warnings alongside the success count.
	if refreshedCount == 0 && len(errors) > 0 {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"error":   fmt.Sprintf("All refresh attempts failed: %v", errors),
		})
		return
	}

	response := gin.H{
		"success":        true,
		"message":        fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
		"refreshedCount": refreshedCount,
	}
	if len(errors) > 0 {
		response["warnings"] = errors
	}

	c.JSON(http.StatusOK, response)
}
|
||||||
|
|
||||||
|
// refreshTokenData refreshes a token using the appropriate method based on auth type.
|
||||||
|
// This mirrors the logic in kiro_executor.Refresh for consistency.
|
||||||
|
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
|
||||||
|
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
|
||||||
|
// IDC refresh with region-specific endpoint
|
||||||
|
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
|
||||||
|
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
|
||||||
|
|
||||||
|
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
|
||||||
|
// Builder ID refresh with default endpoint
|
||||||
|
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
|
||||||
|
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
|
||||||
|
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
|
||||||
|
oauth := NewKiroOAuth(h.cfg)
|
||||||
|
return oauth.RefreshToken(ctx, storage.RefreshToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
779
internal/auth/kiro/oauth_web_templates.go
Normal file
779
internal/auth/kiro/oauth_web_templates.go
Normal file
@@ -0,0 +1,779 @@
|
|||||||
|
// Package kiro provides OAuth Web authentication templates.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
const (
|
||||||
|
oauthWebStartPageHTML = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>AWS SSO Authentication</title>
|
||||||
|
<style>
|
||||||
|
* { box-sizing: border-box; }
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||||
|
margin: 0;
|
||||||
|
padding: 20px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
max-width: 500px;
|
||||||
|
width: 100%;
|
||||||
|
background: #fff;
|
||||||
|
padding: 40px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
margin: 0 0 10px;
|
||||||
|
color: #333;
|
||||||
|
font-size: 24px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.subtitle {
|
||||||
|
text-align: center;
|
||||||
|
color: #666;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
.step {
|
||||||
|
background: #f8f9fa;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
.step-title {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #333;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
.step-number {
|
||||||
|
width: 28px;
|
||||||
|
height: 28px;
|
||||||
|
background: #667eea;
|
||||||
|
color: white;
|
||||||
|
border-radius: 50%;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-right: 12px;
|
||||||
|
}
|
||||||
|
.user-code {
|
||||||
|
background: #e7f3ff;
|
||||||
|
border: 2px dashed #2196F3;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 20px;
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
.user-code-label {
|
||||||
|
font-size: 12px;
|
||||||
|
color: #666;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
}
|
||||||
|
.user-code-value {
|
||||||
|
font-size: 32px;
|
||||||
|
font-weight: bold;
|
||||||
|
font-family: monospace;
|
||||||
|
color: #2196F3;
|
||||||
|
letter-spacing: 4px;
|
||||||
|
}
|
||||||
|
.auth-btn {
|
||||||
|
display: block;
|
||||||
|
width: 100%;
|
||||||
|
padding: 15px;
|
||||||
|
background: #667eea;
|
||||||
|
color: white;
|
||||||
|
text-align: center;
|
||||||
|
text-decoration: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: all 0.3s;
|
||||||
|
border: none;
|
||||||
|
cursor: pointer;
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
.auth-btn:hover {
|
||||||
|
background: #5568d3;
|
||||||
|
transform: translateY(-2px);
|
||||||
|
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||||
|
}
|
||||||
|
.status {
|
||||||
|
margin-top: 30px;
|
||||||
|
padding: 20px;
|
||||||
|
background: #f8f9fa;
|
||||||
|
border-radius: 8px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.status-pending { border-left: 4px solid #ffc107; }
|
||||||
|
.status-success { border-left: 4px solid #28a745; }
|
||||||
|
.status-failed { border-left: 4px solid #dc3545; }
|
||||||
|
.spinner {
|
||||||
|
border: 3px solid #f3f3f3;
|
||||||
|
border-top: 3px solid #667eea;
|
||||||
|
border-radius: 50%;
|
||||||
|
width: 40px;
|
||||||
|
height: 40px;
|
||||||
|
animation: spin 1s linear infinite;
|
||||||
|
margin: 0 auto 15px;
|
||||||
|
}
|
||||||
|
@keyframes spin {
|
||||||
|
0% { transform: rotate(0deg); }
|
||||||
|
100% { transform: rotate(360deg); }
|
||||||
|
}
|
||||||
|
.timer {
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: bold;
|
||||||
|
color: #667eea;
|
||||||
|
margin: 10px 0;
|
||||||
|
}
|
||||||
|
.timer.warning { color: #ffc107; }
|
||||||
|
.timer.danger { color: #dc3545; }
|
||||||
|
.status-message { color: #666; line-height: 1.6; }
|
||||||
|
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
|
||||||
|
.info-box {
|
||||||
|
background: #e7f3ff;
|
||||||
|
border-left: 4px solid #2196F3;
|
||||||
|
padding: 15px;
|
||||||
|
margin-top: 20px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 14px;
|
||||||
|
color: #666;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>🔐 AWS SSO Authentication</h1>
|
||||||
|
<p class="subtitle">Follow the steps below to complete authentication</p>
|
||||||
|
|
||||||
|
<div class="step">
|
||||||
|
<div class="step-title">
|
||||||
|
<span class="step-number">1</span>
|
||||||
|
Click the button below to open the authorization page
|
||||||
|
</div>
|
||||||
|
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
|
||||||
|
🚀 Open Authorization Page
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="step">
|
||||||
|
<div class="step-title">
|
||||||
|
<span class="step-number">2</span>
|
||||||
|
Enter the verification code below
|
||||||
|
</div>
|
||||||
|
<div class="user-code">
|
||||||
|
<div class="user-code-label">Verification Code</div>
|
||||||
|
<div class="user-code-value">{{.UserCode}}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="step">
|
||||||
|
<div class="step-title">
|
||||||
|
<span class="step-number">3</span>
|
||||||
|
Complete AWS SSO login
|
||||||
|
</div>
|
||||||
|
<p style="color: #666; font-size: 14px; margin-top: 10px;">
|
||||||
|
Use your AWS SSO account to login and authorize
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="status status-pending" id="statusBox">
|
||||||
|
<div class="spinner" id="spinner"></div>
|
||||||
|
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
|
||||||
|
<div class="status-message" id="statusMessage">
|
||||||
|
Waiting for authorization...
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="info-box">
|
||||||
|
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
let pollInterval;
|
||||||
|
let timerInterval;
|
||||||
|
let remainingSeconds = {{.ExpiresIn}};
|
||||||
|
const stateID = "{{.StateID}}";
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
document.getElementById('authBtn').click();
|
||||||
|
}, 500);
|
||||||
|
|
||||||
|
function pollStatus() {
|
||||||
|
fetch('/v0/oauth/kiro/status?state=' + stateID)
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
console.log('Status:', data);
|
||||||
|
if (data.status === 'success') {
|
||||||
|
clearInterval(pollInterval);
|
||||||
|
clearInterval(timerInterval);
|
||||||
|
showSuccess(data);
|
||||||
|
} else if (data.status === 'failed') {
|
||||||
|
clearInterval(pollInterval);
|
||||||
|
clearInterval(timerInterval);
|
||||||
|
showError(data);
|
||||||
|
} else {
|
||||||
|
remainingSeconds = data.remaining_seconds || 0;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
console.error('Poll error:', error);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateTimer() {
|
||||||
|
const timerEl = document.getElementById('timer');
|
||||||
|
const minutes = Math.floor(remainingSeconds / 60);
|
||||||
|
const seconds = remainingSeconds % 60;
|
||||||
|
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
|
||||||
|
|
||||||
|
if (remainingSeconds < 60) {
|
||||||
|
timerEl.className = 'timer danger';
|
||||||
|
} else if (remainingSeconds < 180) {
|
||||||
|
timerEl.className = 'timer warning';
|
||||||
|
} else {
|
||||||
|
timerEl.className = 'timer';
|
||||||
|
}
|
||||||
|
|
||||||
|
remainingSeconds--;
|
||||||
|
|
||||||
|
if (remainingSeconds < 0) {
|
||||||
|
clearInterval(timerInterval);
|
||||||
|
clearInterval(pollInterval);
|
||||||
|
showError({ error: 'Authentication timed out. Please refresh and try again.' });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function showSuccess(data) {
|
||||||
|
const statusBox = document.getElementById('statusBox');
|
||||||
|
statusBox.className = 'status status-success';
|
||||||
|
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
|
||||||
|
'<div class="status-message">' +
|
||||||
|
'<strong>Authentication Successful!</strong><br>' +
|
||||||
|
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
|
||||||
|
'</div>';
|
||||||
|
}
|
||||||
|
|
||||||
|
function showError(data) {
|
||||||
|
const statusBox = document.getElementById('statusBox');
|
||||||
|
statusBox.className = 'status status-failed';
|
||||||
|
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
|
||||||
|
'<div class="status-message">' +
|
||||||
|
'<strong>Authentication Failed</strong><br>' +
|
||||||
|
(data.error || 'Unknown error') +
|
||||||
|
'</div>' +
|
||||||
|
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
|
||||||
|
'🔄 Retry' +
|
||||||
|
'</button>';
|
||||||
|
}
|
||||||
|
|
||||||
|
pollInterval = setInterval(pollStatus, 3000);
|
||||||
|
timerInterval = setInterval(updateTimer, 1000);
|
||||||
|
pollStatus();
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
oauthWebErrorPageHTML = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Authentication Failed</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||||
|
max-width: 600px;
|
||||||
|
margin: 50px auto;
|
||||||
|
padding: 20px;
|
||||||
|
background: #f5f5f5;
|
||||||
|
}
|
||||||
|
.error {
|
||||||
|
background: #fff;
|
||||||
|
padding: 30px;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||||
|
border-left: 4px solid #dc3545;
|
||||||
|
}
|
||||||
|
h1 { color: #dc3545; margin-top: 0; }
|
||||||
|
.error-message { color: #666; line-height: 1.6; }
|
||||||
|
.retry-btn {
|
||||||
|
display: inline-block;
|
||||||
|
margin-top: 20px;
|
||||||
|
padding: 10px 20px;
|
||||||
|
background: #007bff;
|
||||||
|
color: white;
|
||||||
|
text-decoration: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.retry-btn:hover { background: #0056b3; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="error">
|
||||||
|
<h1>❌ Authentication Failed</h1>
|
||||||
|
<div class="error-message">
|
||||||
|
<p><strong>Error:</strong></p>
|
||||||
|
<p>{{.Error}}</p>
|
||||||
|
</div>
|
||||||
|
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
oauthWebSuccessPageHTML = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Authentication Successful</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||||
|
max-width: 600px;
|
||||||
|
margin: 50px auto;
|
||||||
|
padding: 20px;
|
||||||
|
background: #f5f5f5;
|
||||||
|
}
|
||||||
|
.success {
|
||||||
|
background: #fff;
|
||||||
|
padding: 30px;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||||
|
border-left: 4px solid #28a745;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
h1 { color: #28a745; margin-top: 0; }
|
||||||
|
.success-message { color: #666; line-height: 1.6; }
|
||||||
|
.icon { font-size: 48px; margin-bottom: 15px; }
|
||||||
|
.expires { font-size: 14px; color: #999; margin-top: 15px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="success">
|
||||||
|
<div class="icon">✅</div>
|
||||||
|
<h1>Authentication Successful!</h1>
|
||||||
|
<div class="success-message">
|
||||||
|
<p>You can close this window.</p>
|
||||||
|
</div>
|
||||||
|
<div class="expires">Token expires: {{.ExpiresAt}}</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
oauthWebSelectPageHTML = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Select Authentication Method</title>
|
||||||
|
<style>
|
||||||
|
* { box-sizing: border-box; }
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||||
|
margin: 0;
|
||||||
|
padding: 20px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
max-width: 500px;
|
||||||
|
width: 100%;
|
||||||
|
background: #fff;
|
||||||
|
padding: 40px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
margin: 0 0 10px;
|
||||||
|
color: #333;
|
||||||
|
font-size: 24px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.subtitle {
|
||||||
|
text-align: center;
|
||||||
|
color: #666;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
.auth-methods {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 15px;
|
||||||
|
}
|
||||||
|
.auth-btn {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
width: 100%;
|
||||||
|
padding: 15px 20px;
|
||||||
|
background: #667eea;
|
||||||
|
color: white;
|
||||||
|
text-decoration: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: all 0.3s;
|
||||||
|
border: none;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
.auth-btn:hover {
|
||||||
|
background: #5568d3;
|
||||||
|
transform: translateY(-2px);
|
||||||
|
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||||
|
}
|
||||||
|
.auth-btn .icon {
|
||||||
|
font-size: 24px;
|
||||||
|
margin-right: 15px;
|
||||||
|
width: 32px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.auth-btn.google { background: #4285F4; }
|
||||||
|
.auth-btn.google:hover { background: #3367D6; }
|
||||||
|
.auth-btn.github { background: #24292e; }
|
||||||
|
.auth-btn.github:hover { background: #1a1e22; }
|
||||||
|
.auth-btn.aws { background: #FF9900; }
|
||||||
|
.auth-btn.aws:hover { background: #E68A00; }
|
||||||
|
.auth-btn.idc { background: #232F3E; }
|
||||||
|
.auth-btn.idc:hover { background: #1a242f; }
|
||||||
|
.idc-form {
|
||||||
|
background: #f8f9fa;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-top: 15px;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
.idc-form.show {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
.form-group label {
|
||||||
|
display: block;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #333;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
.form-group input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 12px;
|
||||||
|
border: 2px solid #e0e0e0;
|
||||||
|
border-radius: 6px;
|
||||||
|
font-size: 14px;
|
||||||
|
transition: border-color 0.3s;
|
||||||
|
}
|
||||||
|
.form-group input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
}
|
||||||
|
.form-group .hint {
|
||||||
|
font-size: 12px;
|
||||||
|
color: #999;
|
||||||
|
margin-top: 5px;
|
||||||
|
}
|
||||||
|
.submit-btn {
|
||||||
|
display: block;
|
||||||
|
width: 100%;
|
||||||
|
padding: 15px;
|
||||||
|
background: #232F3E;
|
||||||
|
color: white;
|
||||||
|
text-align: center;
|
||||||
|
text-decoration: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: all 0.3s;
|
||||||
|
border: none;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
.submit-btn:hover {
|
||||||
|
background: #1a242f;
|
||||||
|
transform: translateY(-2px);
|
||||||
|
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
|
||||||
|
}
|
||||||
|
.divider {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin: 20px 0;
|
||||||
|
}
|
||||||
|
.divider::before,
|
||||||
|
.divider::after {
|
||||||
|
content: "";
|
||||||
|
flex: 1;
|
||||||
|
border-bottom: 1px solid #e0e0e0;
|
||||||
|
}
|
||||||
|
.divider span {
|
||||||
|
padding: 0 15px;
|
||||||
|
color: #999;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
.info-box {
|
||||||
|
background: #e7f3ff;
|
||||||
|
border-left: 4px solid #2196F3;
|
||||||
|
padding: 15px;
|
||||||
|
margin-top: 20px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 14px;
|
||||||
|
color: #666;
|
||||||
|
}
|
||||||
|
.warning-box {
|
||||||
|
background: #fff3cd;
|
||||||
|
border-left: 4px solid #ffc107;
|
||||||
|
padding: 15px;
|
||||||
|
margin-top: 20px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 14px;
|
||||||
|
color: #856404;
|
||||||
|
}
|
||||||
|
.auth-btn.manual { background: #6c757d; }
|
||||||
|
.auth-btn.manual:hover { background: #5a6268; }
|
||||||
|
.auth-btn.refresh { background: #17a2b8; }
|
||||||
|
.auth-btn.refresh:hover { background: #138496; }
|
||||||
|
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
|
||||||
|
.manual-form {
|
||||||
|
background: #f8f9fa;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-top: 15px;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
.manual-form.show {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
.form-group textarea {
|
||||||
|
width: 100%;
|
||||||
|
padding: 12px;
|
||||||
|
border: 2px solid #e0e0e0;
|
||||||
|
border-radius: 6px;
|
||||||
|
font-size: 14px;
|
||||||
|
font-family: monospace;
|
||||||
|
transition: border-color 0.3s;
|
||||||
|
resize: vertical;
|
||||||
|
min-height: 80px;
|
||||||
|
}
|
||||||
|
.form-group textarea:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
}
|
||||||
|
.status-message {
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin-top: 15px;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
.status-message.success {
|
||||||
|
background: #d4edda;
|
||||||
|
color: #155724;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
.status-message.error {
|
||||||
|
background: #f8d7da;
|
||||||
|
color: #721c24;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>🔐 Select Authentication Method</h1>
|
||||||
|
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
|
||||||
|
|
||||||
|
<div class="auth-methods">
|
||||||
|
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
|
||||||
|
<span class="icon">🔶</span>
|
||||||
|
AWS Builder ID (Recommended)
|
||||||
|
</a>
|
||||||
|
|
||||||
|
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
|
||||||
|
<span class="icon">🏢</span>
|
||||||
|
AWS Identity Center (IDC)
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div class="divider"><span>or</span></div>
|
||||||
|
|
||||||
|
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
|
||||||
|
<span class="icon">📋</span>
|
||||||
|
Import RefreshToken from Kiro IDE
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
|
||||||
|
<span class="icon">🔄</span>
|
||||||
|
Manual Refresh All Tokens
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div class="status-message" id="refreshStatus"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="idc-form" id="idcForm">
|
||||||
|
<form action="/v0/oauth/kiro/start" method="get">
|
||||||
|
<input type="hidden" name="method" value="idc">
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="startUrl">Start URL</label>
|
||||||
|
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
|
||||||
|
<div class="hint">Your AWS Identity Center Start URL</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="region">Region</label>
|
||||||
|
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
|
||||||
|
<div class="hint">AWS Region for your Identity Center</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit" class="submit-btn">
|
||||||
|
🚀 Continue with IDC
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="manual-form" id="manualForm">
|
||||||
|
<form id="importForm" onsubmit="submitImport(event)">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="refreshToken">Refresh Token</label>
|
||||||
|
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
|
||||||
|
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit" class="submit-btn" id="importBtn">
|
||||||
|
📥 Import Token
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div class="status-message" id="importStatus"></div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="warning-box">
|
||||||
|
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="info-box">
|
||||||
|
💡 <strong>How to get RefreshToken:</strong><br>
|
||||||
|
1. Open Kiro IDE and login with Google/GitHub<br>
|
||||||
|
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
|
||||||
|
3. Copy the <code>refreshToken</code> value and paste it above
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
function toggleIdcForm() {
|
||||||
|
const idcForm = document.getElementById('idcForm');
|
||||||
|
const manualForm = document.getElementById('manualForm');
|
||||||
|
manualForm.classList.remove('show');
|
||||||
|
idcForm.classList.toggle('show');
|
||||||
|
if (idcForm.classList.contains('show')) {
|
||||||
|
document.getElementById('startUrl').focus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleManualForm() {
|
||||||
|
const idcForm = document.getElementById('idcForm');
|
||||||
|
const manualForm = document.getElementById('manualForm');
|
||||||
|
idcForm.classList.remove('show');
|
||||||
|
manualForm.classList.toggle('show');
|
||||||
|
if (manualForm.classList.contains('show')) {
|
||||||
|
document.getElementById('refreshToken').focus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function submitImport(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
const refreshToken = document.getElementById('refreshToken').value.trim();
|
||||||
|
const statusEl = document.getElementById('importStatus');
|
||||||
|
const btn = document.getElementById('importBtn');
|
||||||
|
|
||||||
|
if (!refreshToken) {
|
||||||
|
statusEl.className = 'status-message error';
|
||||||
|
statusEl.textContent = 'Please enter a refresh token';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!refreshToken.startsWith('aorAAAAAG')) {
|
||||||
|
statusEl.className = 'status-message error';
|
||||||
|
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
btn.disabled = true;
|
||||||
|
btn.textContent = '⏳ Importing...';
|
||||||
|
statusEl.className = 'status-message';
|
||||||
|
statusEl.style.display = 'none';
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/v0/oauth/kiro/import', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ refreshToken: refreshToken })
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (response.ok && data.success) {
|
||||||
|
statusEl.className = 'status-message success';
|
||||||
|
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
|
||||||
|
} else {
|
||||||
|
statusEl.className = 'status-message error';
|
||||||
|
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
statusEl.className = 'status-message error';
|
||||||
|
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||||
|
} finally {
|
||||||
|
btn.disabled = false;
|
||||||
|
btn.textContent = '📥 Import Token';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function manualRefresh() {
|
||||||
|
const btn = document.getElementById('refreshBtn');
|
||||||
|
const statusEl = document.getElementById('refreshStatus');
|
||||||
|
|
||||||
|
btn.disabled = true;
|
||||||
|
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
|
||||||
|
statusEl.className = 'status-message';
|
||||||
|
statusEl.style.display = 'none';
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/v0/oauth/kiro/refresh', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' }
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (response.ok && data.success) {
|
||||||
|
statusEl.className = 'status-message success';
|
||||||
|
let msg = '✅ ' + data.message;
|
||||||
|
if (data.warnings && data.warnings.length > 0) {
|
||||||
|
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
|
||||||
|
}
|
||||||
|
statusEl.textContent = msg;
|
||||||
|
} else {
|
||||||
|
statusEl.className = 'status-message error';
|
||||||
|
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
statusEl.className = 'status-message error';
|
||||||
|
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||||
|
} finally {
|
||||||
|
btn.disabled = false;
|
||||||
|
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
)
|
||||||
316
internal/auth/kiro/rate_limiter.go
Normal file
316
internal/auth/kiro/rate_limiter.go
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultMinTokenInterval = 10 * time.Second
|
||||||
|
DefaultMaxTokenInterval = 30 * time.Second
|
||||||
|
DefaultDailyMaxRequests = 500
|
||||||
|
DefaultJitterPercent = 0.3
|
||||||
|
DefaultBackoffBase = 2 * time.Minute
|
||||||
|
DefaultBackoffMax = 60 * time.Minute
|
||||||
|
DefaultBackoffMultiplier = 2.0
|
||||||
|
DefaultSuspendCooldown = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenState Token 状态
|
||||||
|
type TokenState struct {
|
||||||
|
LastRequest time.Time
|
||||||
|
RequestCount int
|
||||||
|
CooldownEnd time.Time
|
||||||
|
FailCount int
|
||||||
|
DailyRequests int
|
||||||
|
DailyResetTime time.Time
|
||||||
|
IsSuspended bool
|
||||||
|
SuspendedAt time.Time
|
||||||
|
SuspendReason string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiter 频率限制器
|
||||||
|
type RateLimiter struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
states map[string]*TokenState
|
||||||
|
minTokenInterval time.Duration
|
||||||
|
maxTokenInterval time.Duration
|
||||||
|
dailyMaxRequests int
|
||||||
|
jitterPercent float64
|
||||||
|
backoffBase time.Duration
|
||||||
|
backoffMax time.Duration
|
||||||
|
backoffMultiplier float64
|
||||||
|
suspendCooldown time.Duration
|
||||||
|
rng *rand.Rand
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiter 创建默认配置的频率限制器
|
||||||
|
func NewRateLimiter() *RateLimiter {
|
||||||
|
return &RateLimiter{
|
||||||
|
states: make(map[string]*TokenState),
|
||||||
|
minTokenInterval: DefaultMinTokenInterval,
|
||||||
|
maxTokenInterval: DefaultMaxTokenInterval,
|
||||||
|
dailyMaxRequests: DefaultDailyMaxRequests,
|
||||||
|
jitterPercent: DefaultJitterPercent,
|
||||||
|
backoffBase: DefaultBackoffBase,
|
||||||
|
backoffMax: DefaultBackoffMax,
|
||||||
|
backoffMultiplier: DefaultBackoffMultiplier,
|
||||||
|
suspendCooldown: DefaultSuspendCooldown,
|
||||||
|
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiterConfig 频率限制器配置
|
||||||
|
type RateLimiterConfig struct {
|
||||||
|
MinTokenInterval time.Duration
|
||||||
|
MaxTokenInterval time.Duration
|
||||||
|
DailyMaxRequests int
|
||||||
|
JitterPercent float64
|
||||||
|
BackoffBase time.Duration
|
||||||
|
BackoffMax time.Duration
|
||||||
|
BackoffMultiplier float64
|
||||||
|
SuspendCooldown time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
|
||||||
|
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
if cfg.MinTokenInterval > 0 {
|
||||||
|
rl.minTokenInterval = cfg.MinTokenInterval
|
||||||
|
}
|
||||||
|
if cfg.MaxTokenInterval > 0 {
|
||||||
|
rl.maxTokenInterval = cfg.MaxTokenInterval
|
||||||
|
}
|
||||||
|
if cfg.DailyMaxRequests > 0 {
|
||||||
|
rl.dailyMaxRequests = cfg.DailyMaxRequests
|
||||||
|
}
|
||||||
|
if cfg.JitterPercent > 0 {
|
||||||
|
rl.jitterPercent = cfg.JitterPercent
|
||||||
|
}
|
||||||
|
if cfg.BackoffBase > 0 {
|
||||||
|
rl.backoffBase = cfg.BackoffBase
|
||||||
|
}
|
||||||
|
if cfg.BackoffMax > 0 {
|
||||||
|
rl.backoffMax = cfg.BackoffMax
|
||||||
|
}
|
||||||
|
if cfg.BackoffMultiplier > 0 {
|
||||||
|
rl.backoffMultiplier = cfg.BackoffMultiplier
|
||||||
|
}
|
||||||
|
if cfg.SuspendCooldown > 0 {
|
||||||
|
rl.suspendCooldown = cfg.SuspendCooldown
|
||||||
|
}
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrCreateState 获取或创建 Token 状态
|
||||||
|
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
|
||||||
|
state, exists := rl.states[tokenKey]
|
||||||
|
if !exists {
|
||||||
|
state = &TokenState{
|
||||||
|
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
|
||||||
|
}
|
||||||
|
rl.states[tokenKey] = state
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetDailyIfNeeded 如果需要则重置每日计数
|
||||||
|
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(state.DailyResetTime) {
|
||||||
|
state.DailyRequests = 0
|
||||||
|
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateInterval 计算带抖动的随机间隔
|
||||||
|
func (rl *RateLimiter) calculateInterval() time.Duration {
|
||||||
|
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
|
||||||
|
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
|
||||||
|
return baseInterval + jitter
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForToken blocks until a request for tokenKey is allowed, then records
// the request (LastRequest, RequestCount, DailyRequests). It sleeps through
// any active cooldown and then through the randomized minimum interval since
// the previous request. The mutex is released while sleeping so other tokens
// are not blocked.
//
// NOTE(review): after each sleep the state is re-fetched but the cooldown is
// not re-checked, so a cooldown extended by another goroutine during the
// sleep is not honored — confirm whether that is acceptable. The daily cap
// (dailyMaxRequests) is only enforced by IsTokenAvailable, not here.
func (rl *RateLimiter) WaitForToken(tokenKey string) {
	rl.mu.Lock()
	state := rl.getOrCreateState(tokenKey)
	rl.resetDailyIfNeeded(state)

	now := time.Now()

	// If the token is cooling down, sleep (without holding the lock) until
	// the cooldown deadline, then re-acquire and re-fetch the state — it may
	// have been cleared/replaced while we slept.
	if now.Before(state.CooldownEnd) {
		waitTime := state.CooldownEnd.Sub(now)
		rl.mu.Unlock()
		time.Sleep(waitTime)
		rl.mu.Lock()
		state = rl.getOrCreateState(tokenKey)
		now = time.Now()
	}

	// Enforce a randomized spacing since the last recorded request.
	interval := rl.calculateInterval()
	nextAllowedTime := state.LastRequest.Add(interval)

	if now.Before(nextAllowedTime) {
		waitTime := nextAllowedTime.Sub(now)
		rl.mu.Unlock()
		time.Sleep(waitTime)
		rl.mu.Lock()
		state = rl.getOrCreateState(tokenKey)
	}

	// Record this request while still holding the lock.
	state.LastRequest = time.Now()
	state.RequestCount++
	state.DailyRequests++
	rl.mu.Unlock()
}
|
||||||
|
|
||||||
|
// MarkTokenFailed 标记 Token 失败
|
||||||
|
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
state := rl.getOrCreateState(tokenKey)
|
||||||
|
state.FailCount++
|
||||||
|
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkTokenSuccess 标记 Token 成功
|
||||||
|
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
state := rl.getOrCreateState(tokenKey)
|
||||||
|
state.FailCount = 0
|
||||||
|
state.CooldownEnd = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAndMarkSuspended 检测暂停错误并标记
|
||||||
|
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
|
||||||
|
suspendKeywords := []string{
|
||||||
|
"suspended",
|
||||||
|
"banned",
|
||||||
|
"disabled",
|
||||||
|
"account has been",
|
||||||
|
"access denied",
|
||||||
|
"rate limit exceeded",
|
||||||
|
"too many requests",
|
||||||
|
"quota exceeded",
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerMsg := strings.ToLower(errorMsg)
|
||||||
|
for _, keyword := range suspendKeywords {
|
||||||
|
if strings.Contains(lowerMsg, keyword) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
state := rl.getOrCreateState(tokenKey)
|
||||||
|
state.IsSuspended = true
|
||||||
|
state.SuspendedAt = time.Now()
|
||||||
|
state.SuspendReason = errorMsg
|
||||||
|
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTokenAvailable 检查 Token 是否可用
|
||||||
|
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
|
||||||
|
rl.mu.RLock()
|
||||||
|
defer rl.mu.RUnlock()
|
||||||
|
|
||||||
|
state, exists := rl.states[tokenKey]
|
||||||
|
if !exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// 检查是否被暂停
|
||||||
|
if state.IsSuspended {
|
||||||
|
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否在冷却期
|
||||||
|
if now.Before(state.CooldownEnd) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查每日请求限制
|
||||||
|
rl.mu.RUnlock()
|
||||||
|
rl.mu.Lock()
|
||||||
|
rl.resetDailyIfNeeded(state)
|
||||||
|
dailyRequests := state.DailyRequests
|
||||||
|
dailyMax := rl.dailyMaxRequests
|
||||||
|
rl.mu.Unlock()
|
||||||
|
rl.mu.RLock()
|
||||||
|
|
||||||
|
if dailyRequests >= dailyMax {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateBackoff 计算指数退避时间
|
||||||
|
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
|
||||||
|
if failCount <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
|
||||||
|
|
||||||
|
// 添加抖动
|
||||||
|
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
|
||||||
|
backoff += jitter
|
||||||
|
|
||||||
|
if time.Duration(backoff) > rl.backoffMax {
|
||||||
|
return rl.backoffMax
|
||||||
|
}
|
||||||
|
return time.Duration(backoff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenState 获取 Token 状态(只读)
|
||||||
|
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
|
||||||
|
rl.mu.RLock()
|
||||||
|
defer rl.mu.RUnlock()
|
||||||
|
|
||||||
|
state, exists := rl.states[tokenKey]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回副本以防止外部修改
|
||||||
|
stateCopy := *state
|
||||||
|
return &stateCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTokenState 清除 Token 状态
|
||||||
|
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
delete(rl.states, tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetSuspension 重置暂停状态
|
||||||
|
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
state, exists := rl.states[tokenKey]
|
||||||
|
if exists {
|
||||||
|
state.IsSuspended = false
|
||||||
|
state.SuspendedAt = time.Time{}
|
||||||
|
state.SuspendReason = ""
|
||||||
|
state.CooldownEnd = time.Time{}
|
||||||
|
state.FailCount = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// Process-wide RateLimiter, created lazily by GetGlobalRateLimiter.
	globalRateLimiter     *RateLimiter
	globalRateLimiterOnce sync.Once

	// Process-wide CooldownManager plus the channel that stops its cleanup
	// goroutine; both are set inside GetGlobalCooldownManager's Once.
	globalCooldownManager     *CooldownManager
	globalCooldownManagerOnce sync.Once
	cooldownStopCh            chan struct{}
)
|
||||||
|
|
||||||
|
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
|
||||||
|
func GetGlobalRateLimiter() *RateLimiter {
|
||||||
|
globalRateLimiterOnce.Do(func() {
|
||||||
|
globalRateLimiter = NewRateLimiter()
|
||||||
|
log.Info("kiro: global RateLimiter initialized")
|
||||||
|
})
|
||||||
|
return globalRateLimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
// The first call also creates the stop channel and launches the background
// cleanup goroutine (every 5 minutes), which runs until ShutdownRateLimiters
// closes cooldownStopCh.
func GetGlobalCooldownManager() *CooldownManager {
	globalCooldownManagerOnce.Do(func() {
		globalCooldownManager = NewCooldownManager()
		// The channel must exist before the goroutine that listens on it.
		cooldownStopCh = make(chan struct{})
		go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
		log.Info("kiro: global CooldownManager initialized with cleanup routine")
	})
	return globalCooldownManager
}
|
||||||
|
|
||||||
|
// ShutdownRateLimiters stops the cooldown cleanup routine.
|
||||||
|
// Should be called during application shutdown.
|
||||||
|
func ShutdownRateLimiters() {
|
||||||
|
if cooldownStopCh != nil {
|
||||||
|
close(cooldownStopCh)
|
||||||
|
log.Info("kiro: rate limiter cleanup routine stopped")
|
||||||
|
}
|
||||||
|
}
|
||||||
304
internal/auth/kiro/rate_limiter_test.go
Normal file
304
internal/auth/kiro/rate_limiter_test.go
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewRateLimiter(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
if rl == nil {
|
||||||
|
t.Fatal("expected non-nil RateLimiter")
|
||||||
|
}
|
||||||
|
if rl.states == nil {
|
||||||
|
t.Error("expected non-nil states map")
|
||||||
|
}
|
||||||
|
if rl.minTokenInterval != DefaultMinTokenInterval {
|
||||||
|
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
|
||||||
|
}
|
||||||
|
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||||
|
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
|
||||||
|
}
|
||||||
|
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
|
||||||
|
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRateLimiterWithConfig(t *testing.T) {
|
||||||
|
cfg := RateLimiterConfig{
|
||||||
|
MinTokenInterval: 5 * time.Second,
|
||||||
|
MaxTokenInterval: 15 * time.Second,
|
||||||
|
DailyMaxRequests: 100,
|
||||||
|
JitterPercent: 0.2,
|
||||||
|
BackoffBase: 1 * time.Minute,
|
||||||
|
BackoffMax: 30 * time.Minute,
|
||||||
|
BackoffMultiplier: 1.5,
|
||||||
|
SuspendCooldown: 12 * time.Hour,
|
||||||
|
}
|
||||||
|
|
||||||
|
rl := NewRateLimiterWithConfig(cfg)
|
||||||
|
if rl.minTokenInterval != 5*time.Second {
|
||||||
|
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||||
|
}
|
||||||
|
if rl.maxTokenInterval != 15*time.Second {
|
||||||
|
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
|
||||||
|
}
|
||||||
|
if rl.dailyMaxRequests != 100 {
|
||||||
|
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
|
||||||
|
cfg := RateLimiterConfig{
|
||||||
|
MinTokenInterval: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
rl := NewRateLimiterWithConfig(cfg)
|
||||||
|
if rl.minTokenInterval != 5*time.Second {
|
||||||
|
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||||
|
}
|
||||||
|
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||||
|
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTokenState_NonExistent(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
state := rl.GetTokenState("nonexistent")
|
||||||
|
if state != nil {
|
||||||
|
t.Error("expected nil state for non-existent token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTokenAvailable_NewToken(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
if !rl.IsTokenAvailable("newtoken") {
|
||||||
|
t.Error("expected new token to be available")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkTokenFailed(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
rl.MarkTokenFailed("token1")
|
||||||
|
|
||||||
|
state := rl.GetTokenState("token1")
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("expected non-nil state")
|
||||||
|
}
|
||||||
|
if state.FailCount != 1 {
|
||||||
|
t.Errorf("expected FailCount 1, got %d", state.FailCount)
|
||||||
|
}
|
||||||
|
if state.CooldownEnd.IsZero() {
|
||||||
|
t.Error("expected non-zero CooldownEnd")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkTokenSuccess(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
rl.MarkTokenFailed("token1")
|
||||||
|
rl.MarkTokenFailed("token1")
|
||||||
|
rl.MarkTokenSuccess("token1")
|
||||||
|
|
||||||
|
state := rl.GetTokenState("token1")
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("expected non-nil state")
|
||||||
|
}
|
||||||
|
if state.FailCount != 0 {
|
||||||
|
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||||
|
}
|
||||||
|
if !state.CooldownEnd.IsZero() {
|
||||||
|
t.Error("expected zero CooldownEnd after success")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
|
||||||
|
testCases := []string{
|
||||||
|
"Account has been suspended",
|
||||||
|
"You are banned from this service",
|
||||||
|
"Account disabled",
|
||||||
|
"Access denied permanently",
|
||||||
|
"Rate limit exceeded",
|
||||||
|
"Too many requests",
|
||||||
|
"Quota exceeded for today",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range testCases {
|
||||||
|
tokenKey := "token" + string(rune('a'+i))
|
||||||
|
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||||
|
t.Errorf("expected suspension detected for: %s", msg)
|
||||||
|
}
|
||||||
|
state := rl.GetTokenState(tokenKey)
|
||||||
|
if !state.IsSuspended {
|
||||||
|
t.Errorf("expected IsSuspended true for: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
|
||||||
|
normalErrors := []string{
|
||||||
|
"connection timeout",
|
||||||
|
"internal server error",
|
||||||
|
"bad request",
|
||||||
|
"invalid token format",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range normalErrors {
|
||||||
|
tokenKey := "token" + string(rune('a'+i))
|
||||||
|
if rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||||
|
t.Errorf("unexpected suspension for: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTokenAvailable_Suspended(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||||
|
|
||||||
|
if rl.IsTokenAvailable("token1") {
|
||||||
|
t.Error("expected suspended token to be unavailable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearTokenState(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
rl.MarkTokenFailed("token1")
|
||||||
|
rl.ClearTokenState("token1")
|
||||||
|
|
||||||
|
state := rl.GetTokenState("token1")
|
||||||
|
if state != nil {
|
||||||
|
t.Error("expected nil state after clear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetSuspension(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||||
|
rl.ResetSuspension("token1")
|
||||||
|
|
||||||
|
state := rl.GetTokenState("token1")
|
||||||
|
if state.IsSuspended {
|
||||||
|
t.Error("expected IsSuspended false after reset")
|
||||||
|
}
|
||||||
|
if state.FailCount != 0 {
|
||||||
|
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResetSuspension_NonExistent verifies resetting an untracked token is a
// safe no-op (must not panic or create state).
func TestResetSuspension_NonExistent(t *testing.T) {
	rl := NewRateLimiter()
	rl.ResetSuspension("nonexistent")
}
|
||||||
|
|
||||||
|
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
backoff := rl.calculateBackoff(0)
|
||||||
|
if backoff != 0 {
|
||||||
|
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateBackoff_Exponential(t *testing.T) {
|
||||||
|
cfg := RateLimiterConfig{
|
||||||
|
BackoffBase: 1 * time.Minute,
|
||||||
|
BackoffMax: 60 * time.Minute,
|
||||||
|
BackoffMultiplier: 2.0,
|
||||||
|
JitterPercent: 0.3,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiterWithConfig(cfg)
|
||||||
|
|
||||||
|
backoff1 := rl.calculateBackoff(1)
|
||||||
|
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
|
||||||
|
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
|
||||||
|
}
|
||||||
|
|
||||||
|
backoff2 := rl.calculateBackoff(2)
|
||||||
|
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
|
||||||
|
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateBackoff_MaxCap(t *testing.T) {
|
||||||
|
cfg := RateLimiterConfig{
|
||||||
|
BackoffBase: 1 * time.Minute,
|
||||||
|
BackoffMax: 10 * time.Minute,
|
||||||
|
BackoffMultiplier: 2.0,
|
||||||
|
JitterPercent: 0,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiterWithConfig(cfg)
|
||||||
|
|
||||||
|
backoff := rl.calculateBackoff(10)
|
||||||
|
if backoff > 10*time.Minute {
|
||||||
|
t.Errorf("expected backoff capped at 10min, got %v", backoff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTokenState_ReturnsCopy(t *testing.T) {
|
||||||
|
rl := NewRateLimiter()
|
||||||
|
rl.MarkTokenFailed("token1")
|
||||||
|
|
||||||
|
state1 := rl.GetTokenState("token1")
|
||||||
|
state1.FailCount = 999
|
||||||
|
|
||||||
|
state2 := rl.GetTokenState("token1")
|
||||||
|
if state2.FailCount == 999 {
|
||||||
|
t.Error("GetTokenState should return a copy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimiter_ConcurrentAccess hammers the limiter's public API from many
// goroutines over a small shared set of token keys. It asserts nothing about
// results — its purpose is to surface data races (run with -race) and
// deadlocks. WaitForToken is deliberately excluded (it sleeps).
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
	rl := NewRateLimiter()
	const numGoroutines = 50
	const numOperations = 50

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	for i := 0; i < numGoroutines; i++ {
		go func(id int) {
			defer wg.Done()
			// 10 distinct keys shared across 50 goroutines forces contention.
			tokenKey := "token" + string(rune('a'+id%10))
			for j := 0; j < numOperations; j++ {
				switch j % 6 {
				case 0:
					rl.IsTokenAvailable(tokenKey)
				case 1:
					rl.MarkTokenFailed(tokenKey)
				case 2:
					rl.MarkTokenSuccess(tokenKey)
				case 3:
					rl.GetTokenState(tokenKey)
				case 4:
					rl.CheckAndMarkSuspended(tokenKey, "test error")
				case 5:
					rl.ResetSuspension(tokenKey)
				}
			}
		}(i)
	}

	wg.Wait()
}
|
||||||
|
|
||||||
|
func TestCalculateInterval_WithinRange(t *testing.T) {
|
||||||
|
cfg := RateLimiterConfig{
|
||||||
|
MinTokenInterval: 10 * time.Second,
|
||||||
|
MaxTokenInterval: 30 * time.Second,
|
||||||
|
JitterPercent: 0.3,
|
||||||
|
}
|
||||||
|
rl := NewRateLimiterWithConfig(cfg)
|
||||||
|
|
||||||
|
minAllowed := 7 * time.Second
|
||||||
|
maxAllowed := 40 * time.Second
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
interval := rl.calculateInterval()
|
||||||
|
if interval < minAllowed || interval > maxAllowed {
|
||||||
|
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
180
internal/auth/kiro/refresh_manager.go
Normal file
180
internal/auth/kiro/refresh_manager.go
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RefreshManager is the singleton manager for the background token refresher.
// All fields are guarded by mu.
type RefreshManager struct {
	mu        sync.Mutex
	refresher *BackgroundRefresher
	// ctx/cancel bound the refresher's lifetime; created in Start, canceled in Stop.
	ctx     context.Context
	cancel  context.CancelFunc
	started bool
	// onTokenRefreshed is invoked after a successful refresh with the token's
	// ID (file name) and the new token data.
	onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
}
|
||||||
|
|
||||||
|
var (
	// Process-wide RefreshManager, created lazily by GetRefreshManager.
	globalRefreshManager *RefreshManager
	managerOnce          sync.Once
)
|
||||||
|
|
||||||
|
// GetRefreshManager 获取全局刷新管理器实例
|
||||||
|
func GetRefreshManager() *RefreshManager {
|
||||||
|
managerOnce.Do(func() {
|
||||||
|
globalRefreshManager = &RefreshManager{}
|
||||||
|
})
|
||||||
|
return globalRefreshManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize wires up the background token refresher. It is a no-op when
// already started or when baseDir is empty, and never fails in practice (the
// single error return is reserved for future use).
//
// baseDir: directory containing the token files.
// cfg: application config, used to build OAuth/SSO clients for the refresher.
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.started {
		log.Debug("refresh manager: already initialized")
		return nil
	}

	if baseDir == "" {
		log.Warn("refresh manager: base directory not provided, skipping initialization")
		return nil
	}

	// Best-effort resolution of the auth directory; fall back to the raw
	// baseDir when resolution fails or returns empty.
	resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
	if err != nil {
		log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
	}
	if resolvedBaseDir != "" {
		baseDir = resolvedBaseDir
	}

	// File-backed token repository rooted at baseDir.
	repo := NewFileTokenRepository(baseDir)

	// Refresher tuning.
	opts := []RefresherOption{
		WithInterval(time.Minute), // poll once per minute
		WithBatchSize(50),         // at most 50 tokens per batch
		WithConcurrency(10),       // at most 10 concurrent refreshes
		WithConfig(cfg),           // OAuth and SSO clients
	}

	// Forward a callback registered before Initialize was called.
	if m.onTokenRefreshed != nil {
		opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
	}

	m.refresher = NewBackgroundRefresher(repo, opts...)

	log.Infof("refresh manager: initialized with base directory %s", baseDir)
	return nil
}
|
||||||
|
|
||||||
|
// Start 启动后台刷新
|
||||||
|
func (m *RefreshManager) Start() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.started {
|
||||||
|
log.Debug("refresh manager: already started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.refresher == nil {
|
||||||
|
log.Warn("refresh manager: not initialized, cannot start")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||||
|
m.refresher.Start(m.ctx)
|
||||||
|
m.started = true
|
||||||
|
|
||||||
|
log.Info("refresh manager: background refresh started")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止后台刷新
|
||||||
|
func (m *RefreshManager) Stop() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if !m.started {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.cancel != nil {
|
||||||
|
m.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.refresher != nil {
|
||||||
|
m.refresher.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.started = false
|
||||||
|
log.Info("refresh manager: background refresh stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning 检查后台刷新是否正在运行
|
||||||
|
func (m *RefreshManager) IsRunning() bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.started
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBaseDir points the token repository at a new directory (for runtime
// configuration changes). It is a no-op unless the refresher exists and its
// repository is the file-backed implementation.
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.refresher != nil && m.refresher.tokenRepo != nil {
		// Only the file-backed repository supports retargeting.
		if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
			repo.SetBaseDir(baseDir)
			log.Infof("refresh manager: updated base directory to %s", baseDir)
		}
	}
}
|
||||||
|
|
||||||
|
// SetOnTokenRefreshed registers the callback invoked after a token refresh
// succeeds. It may be called at any time, including after Initialize/Start.
// callback receives the tokenID (file name) and the new token data.
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.onTokenRefreshed = callback

	// If the refresher already exists, swap its callback under its own lock
	// so a concurrent refresh observes either the old or new callback.
	if m.refresher != nil {
		m.refresher.callbackMu.Lock()
		m.refresher.onTokenRefreshed = callback
		m.refresher.callbackMu.Unlock()
	}

	log.Debug("refresh manager: token refresh callback registered")
}
|
||||||
|
|
||||||
|
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
|
||||||
|
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||||
|
manager := GetRefreshManager()
|
||||||
|
if err := manager.Initialize(baseDir, cfg); err != nil {
|
||||||
|
log.Errorf("refresh manager: initialization failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
manager.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopGlobalRefreshManager 停止全局刷新管理器
|
||||||
|
func StopGlobalRefreshManager() {
|
||||||
|
if globalRefreshManager != nil {
|
||||||
|
globalRefreshManager.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,7 +9,9 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"html"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -31,6 +33,9 @@ const (
|
|||||||
|
|
||||||
// OAuth timeout
|
// OAuth timeout
|
||||||
socialAuthTimeout = 10 * time.Minute
|
socialAuthTimeout = 10 * time.Minute
|
||||||
|
|
||||||
|
// Default callback port for social auth HTTP server
|
||||||
|
socialAuthCallbackPort = 9876
|
||||||
)
|
)
|
||||||
|
|
||||||
// SocialProvider represents the social login provider.
|
// SocialProvider represents the social login provider.
|
||||||
@@ -67,6 +72,13 @@ type RefreshTokenRequest struct {
|
|||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebCallbackResult contains the OAuth callback result from HTTP server.
type WebCallbackResult struct {
	Code  string // authorization code on success
	State string // state parameter echoed back by the provider
	Error string // provider error parameter, or "state mismatch" on validation failure
}
|
||||||
|
|
||||||
// SocialAuthClient handles social authentication with Kiro.
|
// SocialAuthClient handles social authentication with Kiro.
|
||||||
type SocialAuthClient struct {
|
type SocialAuthClient struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
@@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
// It returns the redirect URI to hand to the provider, a one-shot channel that
// delivers the callback result, and an error if no listener could be bound.
//
// NOTE(review): the shutdown goroutine below also receives from resultChan; if
// it wins the race against the real consumer, the caller never sees the
// result. Confirm the caller's select/timeout behavior before relying on this.
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
	// Try to find an available port - use localhost like Kiro does
	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
	if err != nil {
		// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
		log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
		listener, err = net.Listen("tcp", "localhost:0")
		if err != nil {
			return "", nil, fmt.Errorf("failed to start callback server: %w", err)
		}
	}

	port := listener.Addr().(*net.TCPAddr).Port
	// Use http scheme for local callback server
	redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
	// Buffered so the handler's send never blocks the HTTP response.
	resultChan := make(chan WebCallbackResult, 1)

	server := &http.Server{
		ReadHeaderTimeout: 10 * time.Second,
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		state := r.URL.Query().Get("state")
		errParam := r.URL.Query().Get("error")

		// Provider-reported error: surface it to the user and the caller.
		if errParam != "" {
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
			resultChan <- WebCallbackResult{Error: errParam}
			return
		}

		// CSRF protection: the state must match what we generated.
		if state != expectedState {
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
			resultChan <- WebCallbackResult{Error: "state mismatch"}
			return
		}

		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Successful</title></head>
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
<script>window.close();</script></body></html>`)
		resultChan <- WebCallbackResult{Code: code, State: state}
	})

	server.Handler = mux

	// Serve until the shutdown goroutine below closes the server.
	go func() {
		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
			log.Debugf("kiro social auth callback server error: %v", err)
		}
	}()

	// Shut the server down on caller cancellation, overall auth timeout, or
	// once a callback has been delivered (see NOTE(review) above about this
	// receive competing with the caller).
	go func() {
		select {
		case <-ctx.Done():
		case <-time.After(socialAuthTimeout):
		case <-resultChan:
		}
		_ = server.Shutdown(context.Background())
	}()

	return redirectURI, resultChan, nil
}
|
||||||
|
|
||||||
// generatePKCE generates PKCE code verifier and challenge.
|
// generatePKCE generates PKCE code verifier and challenge.
|
||||||
func generatePKCE() (verifier, challenge string, err error) {
|
func generatePKCE() (verifier, challenge string, err error) {
|
||||||
// Generate 32 bytes of random data for verifier
|
// Generate 32 bytes of random data for verifier
|
||||||
@@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
|
|||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
AuthMethod: "social",
|
AuthMethod: "social",
|
||||||
Provider: "", // Caller should preserve original provider
|
Provider: "", // Caller should preserve original provider
|
||||||
|
Region: "us-east-1",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginWithSocial performs OAuth login with Google.
|
// LoginWithSocial performs OAuth login with Google or GitHub.
|
||||||
|
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
|
||||||
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
||||||
providerName := string(provider)
|
providerName := string(provider)
|
||||||
|
|
||||||
@@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
|||||||
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
||||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
// Step 1: Setup protocol handler
|
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
|
||||||
|
// This avoids redirect_mismatch errors with AWS Cognito
|
||||||
fmt.Println("\nSetting up authentication...")
|
fmt.Println("\nSetting up authentication...")
|
||||||
|
|
||||||
// Start the local callback server
|
|
||||||
handlerPort, err := c.protocolHandler.Start(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
|
||||||
}
|
|
||||||
defer c.protocolHandler.Stop()
|
|
||||||
|
|
||||||
// Ensure protocol handler is installed and set as default
|
|
||||||
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
|
|
||||||
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
|
|
||||||
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
|
|
||||||
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
|
|
||||||
log.Debugf("kiro: protocol handler setup error: %v", err)
|
|
||||||
// Continue anyway - user might have set it up manually or select browser manually
|
|
||||||
} else {
|
|
||||||
// Force set our handler as default (prevents "Open with" dialog)
|
|
||||||
forceDefaultProtocolHandler()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 2: Generate PKCE codes
|
// Step 2: Generate PKCE codes
|
||||||
codeVerifier, codeChallenge, err := generatePKCE()
|
codeVerifier, codeChallenge, err := generatePKCE()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
|||||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4: Build the login URL (Kiro uses GET request with query params)
|
// Step 4: Start local HTTP callback server
|
||||||
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
|
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
|
||||||
|
|
||||||
|
// Step 5: Build the login URL using HTTP redirect URI
|
||||||
|
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
|
||||||
|
|
||||||
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||||
// Incognito mode enables multi-account support by bypassing cached sessions
|
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||||
@@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
|||||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 5: Open browser for user authentication
|
// Step 6: Open browser for user authentication
|
||||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||||
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
||||||
fmt.Println("════════════════════════════════════════════════════════════")
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
@@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
|||||||
|
|
||||||
fmt.Println("\n Waiting for authentication callback...")
|
fmt.Println("\n Waiting for authentication callback...")
|
||||||
|
|
||||||
// Step 6: Wait for callback
|
// Step 7: Wait for callback from HTTP server
|
||||||
callback, err := c.protocolHandler.WaitForCallback(ctx)
|
select {
|
||||||
if err != nil {
|
case <-ctx.Done():
|
||||||
return nil, fmt.Errorf("failed to receive callback: %w", err)
|
return nil, ctx.Err()
|
||||||
}
|
case <-time.After(socialAuthTimeout):
|
||||||
|
return nil, fmt.Errorf("authentication timed out")
|
||||||
if callback.Error != "" {
|
case callback := <-resultChan:
|
||||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
if callback.Error != "" {
|
||||||
}
|
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||||
|
|
||||||
if callback.State != state {
|
|
||||||
// Log state values for debugging, but don't expose in user-facing error
|
|
||||||
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
|
|
||||||
return nil, fmt.Errorf("OAuth state validation failed - please try again")
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback.Code == "" {
|
|
||||||
return nil, fmt.Errorf("no authorization code received")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("\n✓ Authorization received!")
|
|
||||||
|
|
||||||
// Step 7: Exchange code for tokens
|
|
||||||
fmt.Println("Exchanging code for tokens...")
|
|
||||||
|
|
||||||
tokenReq := &CreateTokenRequest{
|
|
||||||
Code: callback.Code,
|
|
||||||
CodeVerifier: codeVerifier,
|
|
||||||
RedirectURI: KiroRedirectURI,
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("\n✓ Authentication successful!")
|
|
||||||
|
|
||||||
// Close the browser window
|
|
||||||
if err := browser.CloseBrowser(); err != nil {
|
|
||||||
log.Debugf("Failed to close browser: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate ExpiresIn - use default 1 hour if invalid
|
|
||||||
expiresIn := tokenResp.ExpiresIn
|
|
||||||
if expiresIn <= 0 {
|
|
||||||
expiresIn = 3600
|
|
||||||
}
|
|
||||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
|
||||||
|
|
||||||
// Try to extract email from JWT access token first
|
|
||||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
|
||||||
|
|
||||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
|
||||||
if email == "" && isInteractiveTerminal() {
|
|
||||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
|
||||||
reader := bufio.NewReader(os.Stdin)
|
|
||||||
var err error
|
|
||||||
email, err = reader.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("Failed to read account label: %v", err)
|
|
||||||
}
|
}
|
||||||
email = strings.TrimSpace(email)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &KiroTokenData{
|
// State is already validated by the callback server
|
||||||
AccessToken: tokenResp.AccessToken,
|
if callback.Code == "" {
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
return nil, fmt.Errorf("no authorization code received")
|
||||||
ProfileArn: tokenResp.ProfileArn,
|
}
|
||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
|
||||||
AuthMethod: "social",
|
fmt.Println("\n✓ Authorization received!")
|
||||||
Provider: providerName,
|
|
||||||
Email: email, // JWT email or user-provided label
|
// Step 8: Exchange code for tokens
|
||||||
}, nil
|
fmt.Println("Exchanging code for tokens...")
|
||||||
|
|
||||||
|
tokenReq := &CreateTokenRequest{
|
||||||
|
Code: callback.Code,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n✓ Authentication successful!")
|
||||||
|
|
||||||
|
// Close the browser window
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
// Try to extract email from JWT access token first
|
||||||
|
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
|
||||||
|
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||||
|
if email == "" && isInteractiveTerminal() {
|
||||||
|
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
var err error
|
||||||
|
email, err = reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to read account label: %v", err)
|
||||||
|
}
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: providerName,
|
||||||
|
Email: email, // JWT email or user-provided label
|
||||||
|
Region: "us-east-1",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginWithGoogle performs OAuth login with Google.
|
// LoginWithGoogle performs OAuth login with Google.
|
||||||
|
|||||||
@@ -735,6 +735,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
|||||||
Provider: "AWS",
|
Provider: "AWS",
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
|
Region: defaultIDCRegion,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -850,16 +851,17 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
|||||||
ClientID: regResp.ClientID,
|
ClientID: regResp.ClientID,
|
||||||
ClientSecret: regResp.ClientSecret,
|
ClientSecret: regResp.ClientSecret,
|
||||||
Email: email,
|
Email: email,
|
||||||
|
Region: defaultIDCRegion,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close browser on timeout for better UX
|
// Close browser on timeout for better UX
|
||||||
if err := browser.CloseBrowser(); err != nil {
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("authorization timed out")
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||||
// Falls back to JWT parsing if userinfo fails.
|
// Falls back to JWT parsing if userinfo fails.
|
||||||
@@ -1366,6 +1368,7 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
|||||||
ClientID: regResp.ClientID,
|
ClientID: regResp.ClientID,
|
||||||
ClientSecret: regResp.ClientSecret,
|
ClientSecret: regResp.ClientSecret,
|
||||||
Email: email,
|
Email: email,
|
||||||
|
Region: defaultIDCRegion,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
|
|
||||||
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
||||||
type KiroTokenStorage struct {
|
type KiroTokenStorage struct {
|
||||||
|
// Type is the provider type for management UI recognition (must be "kiro")
|
||||||
|
Type string `json:"type"`
|
||||||
// AccessToken is the OAuth2 access token for API access
|
// AccessToken is the OAuth2 access token for API access
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
// RefreshToken is used to obtain new access tokens
|
// RefreshToken is used to obtain new access tokens
|
||||||
@@ -23,6 +25,16 @@ type KiroTokenStorage struct {
|
|||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
// LastRefresh is the timestamp of the last token refresh
|
// LastRefresh is the timestamp of the last token refresh
|
||||||
LastRefresh string `json:"last_refresh"`
|
LastRefresh string `json:"last_refresh"`
|
||||||
|
// ClientID is the OAuth client ID (required for token refresh)
|
||||||
|
ClientID string `json:"client_id,omitempty"`
|
||||||
|
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||||
|
ClientSecret string `json:"client_secret,omitempty"`
|
||||||
|
// Region is the AWS region
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||||
|
StartURL string `json:"start_url,omitempty"`
|
||||||
|
// Email is the user's email address
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile persists the token storage to the specified file path.
|
// SaveTokenToFile persists the token storage to the specified file path.
|
||||||
@@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
|||||||
ExpiresAt: s.ExpiresAt,
|
ExpiresAt: s.ExpiresAt,
|
||||||
AuthMethod: s.AuthMethod,
|
AuthMethod: s.AuthMethod,
|
||||||
Provider: s.Provider,
|
Provider: s.Provider,
|
||||||
|
ClientID: s.ClientID,
|
||||||
|
ClientSecret: s.ClientSecret,
|
||||||
|
Region: s.Region,
|
||||||
|
StartURL: s.StartURL,
|
||||||
|
Email: s.Email,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
273
internal/auth/kiro/token_repository.go
Normal file
273
internal/auth/kiro/token_repository.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
|
||||||
|
type FileTokenRepository struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
baseDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileTokenRepository 创建一个新的文件 token 存储库
|
||||||
|
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
|
||||||
|
return &FileTokenRepository{
|
||||||
|
baseDir: baseDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBaseDir 设置基础目录
|
||||||
|
func (r *FileTokenRepository) SetBaseDir(dir string) {
|
||||||
|
r.mu.Lock()
|
||||||
|
r.baseDir = strings.TrimSpace(dir)
|
||||||
|
r.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序)
|
||||||
|
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
|
||||||
|
r.mu.RLock()
|
||||||
|
baseDir := r.baseDir
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if baseDir == "" {
|
||||||
|
log.Debug("token repository: base directory not configured")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []*Token
|
||||||
|
|
||||||
|
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||||
|
if walkErr != nil {
|
||||||
|
return nil // 忽略错误,继续遍历
|
||||||
|
}
|
||||||
|
if d.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只处理 kiro 相关的 token 文件
|
||||||
|
if !strings.HasPrefix(d.Name(), "kiro-") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := r.readTokenFile(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("token repository: failed to read token file %s: %v", path, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if token != nil && token.RefreshToken != "" {
|
||||||
|
// 检查 token 是否需要刷新(过期前 5 分钟)
|
||||||
|
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("token repository: error walking directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按最后验证时间排序(最旧的优先)
|
||||||
|
sort.Slice(tokens, func(i, j int) bool {
|
||||||
|
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 限制返回数量
|
||||||
|
if limit > 0 && len(tokens) > limit {
|
||||||
|
tokens = tokens[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateToken 更新 token 并持久化到文件
|
||||||
|
func (r *FileTokenRepository) UpdateToken(token *Token) error {
|
||||||
|
if token == nil {
|
||||||
|
return fmt.Errorf("token repository: token is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.RLock()
|
||||||
|
baseDir := r.baseDir
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if baseDir == "" {
|
||||||
|
return fmt.Errorf("token repository: base directory not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建文件路径
|
||||||
|
filePath := filepath.Join(baseDir, token.ID)
|
||||||
|
if !strings.HasSuffix(filePath, ".json") {
|
||||||
|
filePath += ".json"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取现有文件内容
|
||||||
|
existingData := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(filePath); err == nil {
|
||||||
|
_ = json.Unmarshal(data, &existingData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新字段
|
||||||
|
existingData["access_token"] = token.AccessToken
|
||||||
|
existingData["refresh_token"] = token.RefreshToken
|
||||||
|
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||||
|
|
||||||
|
if !token.ExpiresAt.IsZero() {
|
||||||
|
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保持原有的关键字段
|
||||||
|
if token.ClientID != "" {
|
||||||
|
existingData["client_id"] = token.ClientID
|
||||||
|
}
|
||||||
|
if token.ClientSecret != "" {
|
||||||
|
existingData["client_secret"] = token.ClientSecret
|
||||||
|
}
|
||||||
|
if token.AuthMethod != "" {
|
||||||
|
existingData["auth_method"] = token.AuthMethod
|
||||||
|
}
|
||||||
|
if token.Region != "" {
|
||||||
|
existingData["region"] = token.Region
|
||||||
|
}
|
||||||
|
if token.StartURL != "" {
|
||||||
|
existingData["start_url"] = token.StartURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化并写入文件
|
||||||
|
raw, err := json.MarshalIndent(existingData, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("token repository: marshal failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原子写入:先写入临时文件,再重命名
|
||||||
|
tmpPath := filePath + ".tmp"
|
||||||
|
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("token repository: write temp file failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Rename(tmpPath, filePath); err != nil {
|
||||||
|
_ = os.Remove(tmpPath)
|
||||||
|
return fmt.Errorf("token repository: rename failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("token repository: updated token %s", token.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readTokenFile 从文件读取 token
|
||||||
|
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var metadata map[string]any
|
||||||
|
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是 kiro token
|
||||||
|
tokenType, _ := metadata["type"].(string)
|
||||||
|
if tokenType != "kiro" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 auth_method
|
||||||
|
authMethod, _ := metadata["auth_method"].(string)
|
||||||
|
if authMethod != "idc" && authMethod != "builder-id" {
|
||||||
|
return nil, nil // 只处理 IDC 和 Builder ID token
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &Token{
|
||||||
|
ID: filepath.Base(path),
|
||||||
|
AuthMethod: authMethod,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析各字段
|
||||||
|
if v, ok := metadata["access_token"].(string); ok {
|
||||||
|
token.AccessToken = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["refresh_token"].(string); ok {
|
||||||
|
token.RefreshToken = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["client_id"].(string); ok {
|
||||||
|
token.ClientID = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["client_secret"].(string); ok {
|
||||||
|
token.ClientSecret = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["region"].(string); ok {
|
||||||
|
token.Region = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["start_url"].(string); ok {
|
||||||
|
token.StartURL = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["provider"].(string); ok {
|
||||||
|
token.Provider = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析时间字段
|
||||||
|
if v, ok := metadata["expires_at"].(string); ok {
|
||||||
|
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||||
|
token.ExpiresAt = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := metadata["last_refresh"].(string); ok {
|
||||||
|
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||||
|
token.LastVerified = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListKiroTokens 列出所有 Kiro token(用于调试)
|
||||||
|
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
baseDir := r.baseDir
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if baseDir == "" {
|
||||||
|
return nil, fmt.Errorf("token repository: base directory not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []*Token
|
||||||
|
|
||||||
|
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||||
|
if walkErr != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if d.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := r.readTokenFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if token != nil {
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
243
internal/auth/kiro/usage_checker.go
Normal file
243
internal/auth/kiro/usage_checker.go
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||||
|
// This file implements usage quota checking and monitoring.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UsageQuotaResponse represents the API response structure for usage quota checking.
|
||||||
|
type UsageQuotaResponse struct {
|
||||||
|
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
|
||||||
|
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||||
|
NextDateReset float64 `json:"nextDateReset,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageBreakdownExtended represents detailed usage information for quota checking.
|
||||||
|
// Note: UsageBreakdown is already defined in codewhisperer_client.go
|
||||||
|
type UsageBreakdownExtended struct {
|
||||||
|
ResourceType string `json:"resourceType"`
|
||||||
|
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||||
|
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||||
|
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FreeTrialInfoExtended represents free trial usage information.
|
||||||
|
type FreeTrialInfoExtended struct {
|
||||||
|
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||||
|
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||||
|
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuotaStatus represents the quota status for a token.
|
||||||
|
type QuotaStatus struct {
|
||||||
|
TotalLimit float64
|
||||||
|
CurrentUsage float64
|
||||||
|
RemainingQuota float64
|
||||||
|
IsExhausted bool
|
||||||
|
ResourceType string
|
||||||
|
NextReset time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageChecker provides methods for checking token quota usage.
|
||||||
|
type UsageChecker struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUsageChecker creates a new UsageChecker instance.
|
||||||
|
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||||
|
return &UsageChecker{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||||
|
endpoint: awsKiroEndpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
|
||||||
|
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||||
|
return &UsageChecker{
|
||||||
|
httpClient: client,
|
||||||
|
endpoint: awsKiroEndpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUsage retrieves usage limits for the given token.
|
||||||
|
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
|
||||||
|
if tokenData == nil {
|
||||||
|
return nil, fmt.Errorf("token data is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
"profileArn": tokenData.ProfileArn,
|
||||||
|
"resourceType": "AGENTIC_REQUEST",
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBody, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", targetGetUsage)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result UsageQuotaResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
|
||||||
|
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
|
||||||
|
tokenData := &KiroTokenData{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
}
|
||||||
|
return c.CheckUsage(ctx, tokenData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRemainingQuota calculates the remaining quota from usage limits.
|
||||||
|
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
|
||||||
|
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalRemaining float64
|
||||||
|
for _, breakdown := range usage.UsageBreakdownList {
|
||||||
|
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||||
|
if remaining > 0 {
|
||||||
|
totalRemaining += remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
if breakdown.FreeTrialInfo != nil {
|
||||||
|
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||||
|
if freeRemaining > 0 {
|
||||||
|
totalRemaining += freeRemaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalRemaining
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
|
||||||
|
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
|
||||||
|
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, breakdown := range usage.UsageBreakdownList {
|
||||||
|
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if breakdown.FreeTrialInfo != nil {
|
||||||
|
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaStatus retrieves a comprehensive quota status for a token.
|
||||||
|
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
|
||||||
|
usage, err := c.CheckUsage(ctx, tokenData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
status := &QuotaStatus{
|
||||||
|
IsExhausted: IsQuotaExhausted(usage),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(usage.UsageBreakdownList) > 0 {
|
||||||
|
breakdown := usage.UsageBreakdownList[0]
|
||||||
|
status.TotalLimit = breakdown.UsageLimitWithPrecision
|
||||||
|
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
|
||||||
|
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||||
|
status.ResourceType = breakdown.ResourceType
|
||||||
|
|
||||||
|
if breakdown.FreeTrialInfo != nil {
|
||||||
|
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||||
|
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||||
|
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||||
|
if freeRemaining > 0 {
|
||||||
|
status.RemainingQuota += freeRemaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.NextDateReset > 0 {
|
||||||
|
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateAvailableCount calculates the available request count based on usage limits.
|
||||||
|
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
|
||||||
|
return GetRemainingQuota(usage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsagePercentage calculates the usage percentage.
|
||||||
|
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
|
||||||
|
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||||
|
return 100.0
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalLimit, totalUsage float64
|
||||||
|
for _, breakdown := range usage.UsageBreakdownList {
|
||||||
|
totalLimit += breakdown.UsageLimitWithPrecision
|
||||||
|
totalUsage += breakdown.CurrentUsageWithPrecision
|
||||||
|
|
||||||
|
if breakdown.FreeTrialInfo != nil {
|
||||||
|
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||||
|
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalLimit == 0 {
|
||||||
|
return 100.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return (totalUsage / totalLimit) * 100
|
||||||
|
}
|
||||||
114
internal/cache/signature_cache.go
vendored
114
internal/cache/signature_cache.go
vendored
@@ -3,6 +3,7 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -23,18 +24,18 @@ const (
|
|||||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||||
MinValidSignatureLen = 50
|
MinValidSignatureLen = 50
|
||||||
|
|
||||||
// SessionCleanupInterval controls how often stale sessions are purged
|
// CacheCleanupInterval controls how often stale entries are purged
|
||||||
SessionCleanupInterval = 10 * time.Minute
|
CacheCleanupInterval = 10 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
|
||||||
var signatureCache sync.Map
|
var signatureCache sync.Map
|
||||||
|
|
||||||
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
|
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
|
||||||
var sessionCleanupOnce sync.Once
|
var cacheCleanupOnce sync.Once
|
||||||
|
|
||||||
// sessionCache is the inner map type
|
// groupCache is the inner map type
|
||||||
type sessionCache struct {
|
type groupCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
entries map[string]SignatureEntry
|
entries map[string]SignatureEntry
|
||||||
}
|
}
|
||||||
@@ -45,36 +46,36 @@ func hashText(text string) string {
|
|||||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrCreateSession gets or creates a session cache
|
// getOrCreateGroupCache gets or creates a cache bucket for a model group
|
||||||
func getOrCreateSession(sessionID string) *sessionCache {
|
func getOrCreateGroupCache(groupKey string) *groupCache {
|
||||||
// Start background cleanup on first access
|
// Start background cleanup on first access
|
||||||
sessionCleanupOnce.Do(startSessionCleanup)
|
cacheCleanupOnce.Do(startCacheCleanup)
|
||||||
|
|
||||||
if val, ok := signatureCache.Load(sessionID); ok {
|
if val, ok := signatureCache.Load(groupKey); ok {
|
||||||
return val.(*sessionCache)
|
return val.(*groupCache)
|
||||||
}
|
}
|
||||||
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
|
sc := &groupCache{entries: make(map[string]SignatureEntry)}
|
||||||
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
|
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
|
||||||
return actual.(*sessionCache)
|
return actual.(*groupCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// startSessionCleanup launches a background goroutine that periodically
|
// startCacheCleanup launches a background goroutine that periodically
|
||||||
// removes sessions where all entries have expired.
|
// removes caches where all entries have expired.
|
||||||
func startSessionCleanup() {
|
func startCacheCleanup() {
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(SessionCleanupInterval)
|
ticker := time.NewTicker(CacheCleanupInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
purgeExpiredSessions()
|
purgeExpiredCaches()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
|
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
|
||||||
func purgeExpiredSessions() {
|
func purgeExpiredCaches() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
signatureCache.Range(func(key, value any) bool {
|
signatureCache.Range(func(key, value any) bool {
|
||||||
sc := value.(*sessionCache)
|
sc := value.(*groupCache)
|
||||||
sc.mu.Lock()
|
sc.mu.Lock()
|
||||||
// Remove expired entries
|
// Remove expired entries
|
||||||
for k, entry := range sc.entries {
|
for k, entry := range sc.entries {
|
||||||
@@ -84,7 +85,7 @@ func purgeExpiredSessions() {
|
|||||||
}
|
}
|
||||||
isEmpty := len(sc.entries) == 0
|
isEmpty := len(sc.entries) == 0
|
||||||
sc.mu.Unlock()
|
sc.mu.Unlock()
|
||||||
// Remove session if empty
|
// Remove cache bucket if empty
|
||||||
if isEmpty {
|
if isEmpty {
|
||||||
signatureCache.Delete(key)
|
signatureCache.Delete(key)
|
||||||
}
|
}
|
||||||
@@ -92,19 +93,19 @@ func purgeExpiredSessions() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheSignature stores a thinking signature for a given session and text.
|
// CacheSignature stores a thinking signature for a given model group and text.
|
||||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||||
func CacheSignature(sessionID, text, signature string) {
|
func CacheSignature(modelName, text, signature string) {
|
||||||
if sessionID == "" || text == "" || signature == "" {
|
if text == "" || signature == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(signature) < MinValidSignatureLen {
|
if len(signature) < MinValidSignatureLen {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sc := getOrCreateSession(sessionID)
|
groupKey := GetModelGroup(modelName)
|
||||||
textHash := hashText(text)
|
textHash := hashText(text)
|
||||||
|
sc := getOrCreateGroupCache(groupKey)
|
||||||
sc.mu.Lock()
|
sc.mu.Lock()
|
||||||
defer sc.mu.Unlock()
|
defer sc.mu.Unlock()
|
||||||
|
|
||||||
@@ -114,18 +115,25 @@ func CacheSignature(sessionID, text, signature string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
// GetCachedSignature retrieves a cached signature for a given model group and text.
|
||||||
// Returns empty string if not found or expired.
|
// Returns empty string if not found or expired.
|
||||||
func GetCachedSignature(sessionID, text string) string {
|
func GetCachedSignature(modelName, text string) string {
|
||||||
if sessionID == "" || text == "" {
|
groupKey := GetModelGroup(modelName)
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
val, ok := signatureCache.Load(sessionID)
|
if text == "" {
|
||||||
if !ok {
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
sc := val.(*sessionCache)
|
val, ok := signatureCache.Load(groupKey)
|
||||||
|
if !ok {
|
||||||
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sc := val.(*groupCache)
|
||||||
|
|
||||||
textHash := hashText(text)
|
textHash := hashText(text)
|
||||||
|
|
||||||
@@ -135,11 +143,17 @@ func GetCachedSignature(sessionID, text string) string {
|
|||||||
entry, exists := sc.entries[textHash]
|
entry, exists := sc.entries[textHash]
|
||||||
if !exists {
|
if !exists {
|
||||||
sc.mu.Unlock()
|
sc.mu.Unlock()
|
||||||
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||||
delete(sc.entries, textHash)
|
delete(sc.entries, textHash)
|
||||||
sc.mu.Unlock()
|
sc.mu.Unlock()
|
||||||
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,19 +165,31 @@ func GetCachedSignature(sessionID, text string) string {
|
|||||||
return entry.Signature
|
return entry.Signature
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
// ClearSignatureCache clears signature cache for a specific model group or all groups.
|
||||||
func ClearSignatureCache(sessionID string) {
|
func ClearSignatureCache(modelName string) {
|
||||||
if sessionID != "" {
|
if modelName == "" {
|
||||||
signatureCache.Delete(sessionID)
|
|
||||||
} else {
|
|
||||||
signatureCache.Range(func(key, _ any) bool {
|
signatureCache.Range(func(key, _ any) bool {
|
||||||
signatureCache.Delete(key)
|
signatureCache.Delete(key)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
groupKey := GetModelGroup(modelName)
|
||||||
|
signatureCache.Delete(groupKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||||
func HasValidSignature(signature string) bool {
|
func HasValidSignature(modelName, signature string) bool {
|
||||||
return signature != "" && len(signature) >= MinValidSignatureLen
|
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelGroup(modelName string) string {
|
||||||
|
if strings.Contains(modelName, "gpt") {
|
||||||
|
return "gpt"
|
||||||
|
} else if strings.Contains(modelName, "claude") {
|
||||||
|
return "claude"
|
||||||
|
} else if strings.Contains(modelName, "gemini") {
|
||||||
|
return "gemini"
|
||||||
|
}
|
||||||
|
return modelName
|
||||||
}
|
}
|
||||||
|
|||||||
110
internal/cache/signature_cache_test.go
vendored
110
internal/cache/signature_cache_test.go
vendored
@@ -5,38 +5,40 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testModelName = "claude-sonnet-4-5"
|
||||||
|
|
||||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "test-session-1"
|
|
||||||
text := "This is some thinking text content"
|
text := "This is some thinking text content"
|
||||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
|
||||||
// Store signature
|
// Store signature
|
||||||
CacheSignature(sessionID, text, signature)
|
CacheSignature(testModelName, text, signature)
|
||||||
|
|
||||||
// Retrieve signature
|
// Retrieve signature
|
||||||
retrieved := GetCachedSignature(sessionID, text)
|
retrieved := GetCachedSignature(testModelName, text)
|
||||||
if retrieved != signature {
|
if retrieved != signature {
|
||||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheSignature_DifferentSessions(t *testing.T) {
|
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
text := "Same text in different sessions"
|
text := "Same text across models"
|
||||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||||
|
|
||||||
CacheSignature("session-a", text, sig1)
|
geminiModel := "gemini-3-pro-preview"
|
||||||
CacheSignature("session-b", text, sig2)
|
CacheSignature(testModelName, text, sig1)
|
||||||
|
CacheSignature(geminiModel, text, sig2)
|
||||||
|
|
||||||
if GetCachedSignature("session-a", text) != sig1 {
|
if GetCachedSignature(testModelName, text) != sig1 {
|
||||||
t.Error("Session-a signature mismatch")
|
t.Error("Claude signature mismatch")
|
||||||
}
|
}
|
||||||
if GetCachedSignature("session-b", text) != sig2 {
|
if GetCachedSignature(geminiModel, text) != sig2 {
|
||||||
t.Error("Session-b signature mismatch")
|
t.Error("Gemini signature mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,13 +46,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
|
|||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
// Non-existent session
|
// Non-existent session
|
||||||
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
|
if got := GetCachedSignature(testModelName, "some text"); got != "" {
|
||||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Existing session but different text
|
// Existing session but different text
|
||||||
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||||
if got := GetCachedSignature("session-x", "text-b"); got != "" {
|
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
|
||||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -59,12 +61,11 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
|||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
// All empty/invalid inputs should be no-ops
|
// All empty/invalid inputs should be no-ops
|
||||||
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
|
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
|
||||||
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
|
CacheSignature(testModelName, "text", "")
|
||||||
CacheSignature("session", "text", "")
|
CacheSignature(testModelName, "text", "short") // Too short
|
||||||
CacheSignature("session", "text", "short") // Too short
|
|
||||||
|
|
||||||
if got := GetCachedSignature("session", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -72,31 +73,27 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
|||||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "test-short-sig"
|
|
||||||
text := "Some text"
|
text := "Some text"
|
||||||
shortSig := "abc123" // Less than 50 chars
|
shortSig := "abc123" // Less than 50 chars
|
||||||
|
|
||||||
CacheSignature(sessionID, text, shortSig)
|
CacheSignature(testModelName, text, shortSig)
|
||||||
|
|
||||||
if got := GetCachedSignature(sessionID, text); got != "" {
|
if got := GetCachedSignature(testModelName, text); got != "" {
|
||||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
func TestClearSignatureCache_ModelGroup(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||||
CacheSignature("session-1", "text", sig)
|
CacheSignature(testModelName, "text", sig)
|
||||||
CacheSignature("session-2", "text", sig)
|
CacheSignature(testModelName, "text-2", sig)
|
||||||
|
|
||||||
ClearSignatureCache("session-1")
|
ClearSignatureCache("session-1")
|
||||||
|
|
||||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text"); got != sig {
|
||||||
t.Error("session-1 should be cleared")
|
t.Error("signature should remain when clearing unknown session")
|
||||||
}
|
|
||||||
if got := GetCachedSignature("session-2", "text"); got != sig {
|
|
||||||
t.Error("session-2 should still exist")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,35 +101,37 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
|
|||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||||
CacheSignature("session-1", "text", sig)
|
CacheSignature(testModelName, "text", sig)
|
||||||
CacheSignature("session-2", "text", sig)
|
CacheSignature(testModelName, "text-2", sig)
|
||||||
|
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||||
t.Error("session-1 should be cleared")
|
t.Error("text should be cleared")
|
||||||
}
|
}
|
||||||
if got := GetCachedSignature("session-2", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
|
||||||
t.Error("session-2 should be cleared")
|
t.Error("text-2 should be cleared")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHasValidSignature(t *testing.T) {
|
func TestHasValidSignature(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
modelName string
|
||||||
signature string
|
signature string
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
|
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||||
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
|
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
|
||||||
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
|
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
|
||||||
{"empty string", "", false},
|
{"empty string", testModelName, "", false},
|
||||||
{"short signature", "abc", false},
|
{"short signature", testModelName, "abc", false},
|
||||||
|
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := HasValidSignature(tt.signature)
|
result := HasValidSignature(tt.modelName, tt.signature)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||||
}
|
}
|
||||||
@@ -143,21 +142,19 @@ func TestHasValidSignature(t *testing.T) {
|
|||||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "hash-test-session"
|
|
||||||
|
|
||||||
// Different texts should produce different hashes
|
// Different texts should produce different hashes
|
||||||
text1 := "First thinking text"
|
text1 := "First thinking text"
|
||||||
text2 := "Second thinking text"
|
text2 := "Second thinking text"
|
||||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||||
|
|
||||||
CacheSignature(sessionID, text1, sig1)
|
CacheSignature(testModelName, text1, sig1)
|
||||||
CacheSignature(sessionID, text2, sig2)
|
CacheSignature(testModelName, text2, sig2)
|
||||||
|
|
||||||
if GetCachedSignature(sessionID, text1) != sig1 {
|
if GetCachedSignature(testModelName, text1) != sig1 {
|
||||||
t.Error("text1 signature mismatch")
|
t.Error("text1 signature mismatch")
|
||||||
}
|
}
|
||||||
if GetCachedSignature(sessionID, text2) != sig2 {
|
if GetCachedSignature(testModelName, text2) != sig2 {
|
||||||
t.Error("text2 signature mismatch")
|
t.Error("text2 signature mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,13 +162,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
|||||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "unicode-session"
|
|
||||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||||
|
|
||||||
CacheSignature(sessionID, text, sig)
|
CacheSignature(testModelName, text, sig)
|
||||||
|
|
||||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -179,15 +175,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
|
|||||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "overwrite-session"
|
|
||||||
text := "Same text"
|
text := "Same text"
|
||||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||||
|
|
||||||
CacheSignature(sessionID, text, sig1)
|
CacheSignature(testModelName, text, sig1)
|
||||||
CacheSignature(sessionID, text, sig2) // Overwrite
|
CacheSignature(testModelName, text, sig2) // Overwrite
|
||||||
|
|
||||||
if got := GetCachedSignature(sessionID, text); got != sig2 {
|
if got := GetCachedSignature(testModelName, text); got != sig2 {
|
||||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,14 +194,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
|||||||
|
|
||||||
// This test verifies the expiration check exists
|
// This test verifies the expiration check exists
|
||||||
// In a real scenario, we'd mock time.Now()
|
// In a real scenario, we'd mock time.Now()
|
||||||
sessionID := "expiration-test"
|
|
||||||
text := "text"
|
text := "text"
|
||||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||||
|
|
||||||
CacheSignature(sessionID, text, sig)
|
CacheSignature(testModelName, text, sig)
|
||||||
|
|
||||||
// Fresh entry should be retrievable
|
// Fresh entry should be retrievable
|
||||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
activatedProjects := make([]string, 0, len(projectSelections))
|
activatedProjects := make([]string, 0, len(projectSelections))
|
||||||
|
seenProjects := make(map[string]bool)
|
||||||
for _, candidateID := range projectSelections {
|
for _, candidateID := range projectSelections {
|
||||||
log.Infof("Activating project %s", candidateID)
|
log.Infof("Activating project %s", candidateID)
|
||||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||||
@@ -134,6 +135,13 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
if finalID == "" {
|
if finalID == "" {
|
||||||
finalID = candidateID
|
finalID = candidateID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip duplicates
|
||||||
|
if seenProjects[finalID] {
|
||||||
|
log.Infof("Project %s already activated, skipping", finalID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenProjects[finalID] = true
|
||||||
activatedProjects = append(activatedProjects, finalID)
|
activatedProjects = append(activatedProjects, finalID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,8 +269,39 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
|||||||
finalProjectID := projectID
|
finalProjectID := projectID
|
||||||
if responseProjectID != "" {
|
if responseProjectID != "" {
|
||||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; using response project ID.", responseProjectID, projectID)
|
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||||
finalProjectID = responseProjectID
|
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||||
|
strings.EqualFold(tierID, "FREE") ||
|
||||||
|
strings.EqualFold(tierID, "LEGACY")
|
||||||
|
|
||||||
|
if isFreeUser {
|
||||||
|
// Interactive prompt for free users
|
||||||
|
fmt.Printf("\nGoogle returned a different project ID:\n")
|
||||||
|
fmt.Printf(" Requested (frontend): %s\n", projectID)
|
||||||
|
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
|
||||||
|
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
|
||||||
|
fmt.Printf(" This is normal for free tier users.\n\n")
|
||||||
|
fmt.Printf("Which project ID would you like to use?\n")
|
||||||
|
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
|
||||||
|
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
|
||||||
|
fmt.Printf("Enter choice [1]: ")
|
||||||
|
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
choice, _ := reader.ReadString('\n')
|
||||||
|
choice = strings.TrimSpace(choice)
|
||||||
|
|
||||||
|
if choice == "2" {
|
||||||
|
log.Infof("Using frontend project ID: %s", projectID)
|
||||||
|
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
|
||||||
|
finalProjectID = projectID
|
||||||
|
} else {
|
||||||
|
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
|
||||||
|
finalProjectID = responseProjectID
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Pro users: keep requested project ID (original behavior)
|
||||||
|
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
finalProjectID = responseProjectID
|
finalProjectID = responseProjectID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -261,6 +261,25 @@ type PayloadModelRule struct {
|
|||||||
Protocol string `yaml:"protocol" json:"protocol"`
|
Protocol string `yaml:"protocol" json:"protocol"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloakConfig configures request cloaking for non-Claude-Code clients.
|
||||||
|
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
|
||||||
|
type CloakConfig struct {
|
||||||
|
// Mode controls cloaking behavior: "auto" (default), "always", or "never".
|
||||||
|
// - "auto": cloak only when client is not Claude Code (based on User-Agent)
|
||||||
|
// - "always": always apply cloaking regardless of client
|
||||||
|
// - "never": never apply cloaking
|
||||||
|
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||||
|
|
||||||
|
// StrictMode controls how system prompts are handled when cloaking.
|
||||||
|
// - false (default): prepend Claude Code prompt to user system messages
|
||||||
|
// - true: strip all user system messages, keep only Claude Code prompt
|
||||||
|
StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
|
||||||
|
|
||||||
|
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||||
|
// This can help bypass certain content filters.
|
||||||
|
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ClaudeKey represents the configuration for a Claude API key,
|
// ClaudeKey represents the configuration for a Claude API key,
|
||||||
// including the API key itself and an optional base URL for the API endpoint.
|
// including the API key itself and an optional base URL for the API endpoint.
|
||||||
type ClaudeKey struct {
|
type ClaudeKey struct {
|
||||||
@@ -289,6 +308,9 @@ type ClaudeKey struct {
|
|||||||
|
|
||||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
|
|
||||||
|
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||||
|
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
@@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool {
|
|||||||
// - gin.HandlerFunc: A middleware handler for panic recovery
|
// - gin.HandlerFunc: A middleware handler for panic recovery
|
||||||
func GinLogrusRecovery() gin.HandlerFunc {
|
func GinLogrusRecovery() gin.HandlerFunc {
|
||||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||||
|
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||||
|
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"panic": recovered,
|
"panic": recovered,
|
||||||
"stack": string(debug.Stack()),
|
"stack": string(debug.Stack()),
|
||||||
|
|||||||
60
internal/logging/gin_logger_test.go
Normal file
60
internal/logging/gin_logger_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
engine := gin.New()
|
||||||
|
engine.Use(GinLogrusRecovery())
|
||||||
|
engine.GET("/abort", func(c *gin.Context) {
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
recovered := recover()
|
||||||
|
if recovered == nil {
|
||||||
|
t.Fatalf("expected panic, got nil")
|
||||||
|
}
|
||||||
|
err, ok := recovered.(error)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected error panic, got %T", recovered)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, http.ErrAbortHandler) {
|
||||||
|
t.Fatalf("expected ErrAbortHandler, got %v", err)
|
||||||
|
}
|
||||||
|
if err != http.ErrAbortHandler {
|
||||||
|
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
engine.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
engine := gin.New()
|
||||||
|
engine.Use(GinLogrusRecovery())
|
||||||
|
engine.GET("/panic", func(c *gin.Context) {
|
||||||
|
panic("boom")
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
engine.ServeHTTP(recorder, req)
|
||||||
|
if recorder.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
303
internal/registry/kiro_model_converter.go
Normal file
303
internal/registry/kiro_model_converter.go
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
// Package registry provides Kiro model conversion utilities.
|
||||||
|
// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
|
||||||
|
// and merging with static metadata for thinking support and other capabilities.
|
||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KiroAPIModel represents a model from Kiro API response.
|
||||||
|
// This is a local copy to avoid import cycles with the kiro package.
|
||||||
|
// The structure mirrors kiro.KiroModel for easy data conversion.
|
||||||
|
type KiroAPIModel struct {
|
||||||
|
// ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
|
||||||
|
ModelID string
|
||||||
|
// ModelName is the human-readable name
|
||||||
|
ModelName string
|
||||||
|
// Description is the model description
|
||||||
|
Description string
|
||||||
|
// RateMultiplier is the credit multiplier for this model
|
||||||
|
RateMultiplier float64
|
||||||
|
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||||
|
RateUnit string
|
||||||
|
// MaxInputTokens is the maximum input token limit
|
||||||
|
MaxInputTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
|
||||||
|
// All Kiro models support thinking with the following budget range.
|
||||||
|
var DefaultKiroThinkingSupport = &ThinkingSupport{
|
||||||
|
Min: 1024, // Minimum thinking budget tokens
|
||||||
|
Max: 32000, // Maximum thinking budget tokens
|
||||||
|
ZeroAllowed: true, // Allow disabling thinking with 0
|
||||||
|
DynamicAllowed: true, // Allow dynamic thinking budget (-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultKiroContextLength is the default context window size for Kiro models.
|
||||||
|
const DefaultKiroContextLength = 200000
|
||||||
|
|
||||||
|
// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
|
||||||
|
const DefaultKiroMaxCompletionTokens = 64000
|
||||||
|
|
||||||
|
// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
|
||||||
|
// It performs the following transformations:
|
||||||
|
// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
|
||||||
|
// - Adds default thinking support metadata
|
||||||
|
// - Sets default context length and max completion tokens if not provided
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - kiroModels: List of models from Kiro API response
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*ModelInfo: Converted model information list
|
||||||
|
func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
|
||||||
|
if len(kiroModels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
result := make([]*ModelInfo, 0, len(kiroModels))
|
||||||
|
|
||||||
|
for _, km := range kiroModels {
|
||||||
|
// Skip nil models
|
||||||
|
if km == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip models without valid ID
|
||||||
|
if km.ModelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize the model ID to kiro-* format
|
||||||
|
normalizedID := normalizeKiroModelID(km.ModelID)
|
||||||
|
|
||||||
|
// Create ModelInfo with converted data
|
||||||
|
info := &ModelInfo{
|
||||||
|
ID: normalizedID,
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
|
||||||
|
Description: km.Description,
|
||||||
|
// Use MaxInputTokens from API if available, otherwise use default
|
||||||
|
ContextLength: getContextLength(km.MaxInputTokens),
|
||||||
|
MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
|
||||||
|
// All Kiro models support thinking
|
||||||
|
Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAgenticVariants creates -agentic variants for each model.
|
||||||
|
// Agentic variants are optimized for coding agents with chunked writes.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - models: Base models to generate variants for
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*ModelInfo: Combined list of base models and their agentic variants
|
||||||
|
func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-allocate result with capacity for both base models and variants
|
||||||
|
result := make([]*ModelInfo, 0, len(models)*2)
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the base model first
|
||||||
|
result = append(result, model)
|
||||||
|
|
||||||
|
// Skip if model already has -agentic suffix
|
||||||
|
if strings.HasSuffix(model.ID, "-agentic") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip special models that shouldn't have agentic variants
|
||||||
|
if model.ID == "kiro-auto" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create agentic variant
|
||||||
|
agenticModel := &ModelInfo{
|
||||||
|
ID: model.ID + "-agentic",
|
||||||
|
Object: model.Object,
|
||||||
|
Created: model.Created,
|
||||||
|
OwnedBy: model.OwnedBy,
|
||||||
|
Type: model.Type,
|
||||||
|
DisplayName: model.DisplayName + " (Agentic)",
|
||||||
|
Description: generateAgenticDescription(model.Description),
|
||||||
|
ContextLength: model.ContextLength,
|
||||||
|
MaxCompletionTokens: model.MaxCompletionTokens,
|
||||||
|
Thinking: cloneThinkingSupport(model.Thinking),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, agenticModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeWithStaticMetadata merges dynamic models with static metadata.
|
||||||
|
// Static metadata takes priority for any overlapping fields.
|
||||||
|
// This allows manual overrides for specific models while keeping dynamic discovery.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - dynamicModels: Models from Kiro API (converted to ModelInfo)
|
||||||
|
// - staticModels: Predefined model metadata (from GetKiroModels())
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*ModelInfo: Merged model list with static metadata taking priority
|
||||||
|
func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(dynamicModels) == 0 && len(staticModels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map of static models for quick lookup
|
||||||
|
staticMap := make(map[string]*ModelInfo, len(staticModels))
|
||||||
|
for _, sm := range staticModels {
|
||||||
|
if sm != nil && sm.ID != "" {
|
||||||
|
staticMap[sm.ID] = sm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build result, preferring static metadata where available
|
||||||
|
seenIDs := make(map[string]struct{})
|
||||||
|
result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
|
||||||
|
|
||||||
|
// First, process dynamic models and merge with static if available
|
||||||
|
for _, dm := range dynamicModels {
|
||||||
|
if dm == nil || dm.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip duplicates
|
||||||
|
if _, seen := seenIDs[dm.ID]; seen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenIDs[dm.ID] = struct{}{}
|
||||||
|
|
||||||
|
// Check if static metadata exists for this model
|
||||||
|
if sm, exists := staticMap[dm.ID]; exists {
|
||||||
|
// Static metadata takes priority - use static model
|
||||||
|
result = append(result, sm)
|
||||||
|
} else {
|
||||||
|
// No static metadata - use dynamic model
|
||||||
|
result = append(result, dm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add any static models not in dynamic list
|
||||||
|
for _, sm := range staticModels {
|
||||||
|
if sm == nil || sm.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, seen := seenIDs[sm.ID]; seen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenIDs[sm.ID] = struct{}{}
|
||||||
|
result = append(result, sm)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeKiroModelID converts Kiro API model IDs to internal format.
|
||||||
|
// Transformation rules:
|
||||||
|
// - Adds "kiro-" prefix if not present
|
||||||
|
// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
|
||||||
|
// - Handles special cases like "auto" → "kiro-auto"
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
|
||||||
|
// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
|
||||||
|
// - "auto" → "kiro-auto"
|
||||||
|
// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
|
||||||
|
func normalizeKiroModelID(modelID string) string {
|
||||||
|
if modelID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim whitespace
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
|
||||||
|
// Replace dots with hyphens (e.g., 4.5 → 4-5)
|
||||||
|
normalized := strings.ReplaceAll(modelID, ".", "-")
|
||||||
|
|
||||||
|
// Add kiro- prefix if not present
|
||||||
|
if !strings.HasPrefix(normalized, "kiro-") {
|
||||||
|
normalized = "kiro-" + normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKiroDisplayName creates a human-readable display name.
|
||||||
|
// Uses the API-provided model name if available, otherwise generates from ID.
|
||||||
|
func generateKiroDisplayName(modelName, normalizedID string) string {
|
||||||
|
if modelName != "" {
|
||||||
|
return "Kiro " + modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate from normalized ID by removing kiro- prefix and formatting
|
||||||
|
displayID := strings.TrimPrefix(normalizedID, "kiro-")
|
||||||
|
// Capitalize first letter of each word
|
||||||
|
words := strings.Split(displayID, "-")
|
||||||
|
for i, word := range words {
|
||||||
|
if len(word) > 0 {
|
||||||
|
words[i] = strings.ToUpper(word[:1]) + word[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "Kiro " + strings.Join(words, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateAgenticDescription creates description for agentic variants.
|
||||||
|
func generateAgenticDescription(baseDescription string) string {
|
||||||
|
if baseDescription == "" {
|
||||||
|
return "Optimized for coding agents with chunked writes"
|
||||||
|
}
|
||||||
|
return baseDescription + " (Agentic mode: chunked writes)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getContextLength returns the context length, using default if not provided.
|
||||||
|
func getContextLength(maxInputTokens int) int {
|
||||||
|
if maxInputTokens > 0 {
|
||||||
|
return maxInputTokens
|
||||||
|
}
|
||||||
|
return DefaultKiroContextLength
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneThinkingSupport creates a deep copy of ThinkingSupport.
|
||||||
|
// Returns nil if input is nil.
|
||||||
|
func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
|
||||||
|
if ts == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
clone := &ThinkingSupport{
|
||||||
|
Min: ts.Min,
|
||||||
|
Max: ts.Max,
|
||||||
|
ZeroAllowed: ts.ZeroAllowed,
|
||||||
|
DynamicAllowed: ts.DynamicAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deep copy Levels slice if present
|
||||||
|
if len(ts.Levels) > 0 {
|
||||||
|
clone.Levels = make([]string, len(ts.Levels))
|
||||||
|
copy(clone.Levels, ts.Levels)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone
|
||||||
|
}
|
||||||
@@ -287,6 +287,67 @@ func GetGeminiVertexModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
// Imagen image generation models - use :predict action
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Generate",
|
||||||
|
Description: "Imagen 4.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-ultra-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-ultra-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Ultra Generate",
|
||||||
|
Description: "Imagen 4.0 Ultra high-quality image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-generate-002",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-generate-002",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Generate",
|
||||||
|
Description: "Imagen 3.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-fast-generate-001",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Fast Generate",
|
||||||
|
Description: "Imagen 3.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-fast-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Fast Generate",
|
||||||
|
Description: "Imagen 4.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -765,21 +826,23 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
type AntigravityModelConfig struct {
|
type AntigravityModelConfig struct {
|
||||||
Thinking *ThinkingSupport
|
Thinking *ThinkingSupport
|
||||||
MaxCompletionTokens int
|
MaxCompletionTokens int
|
||||||
Name string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
return map[string]*AntigravityModelConfig{
|
return map[string]*AntigravityModelConfig{
|
||||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/rev19-uic3-1p"},
|
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-high"},
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image"},
|
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash"},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||||
|
"gpt-oss-120b-medium": {},
|
||||||
|
"tab_flash_lite_preview": {},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -809,10 +872,9 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
// Check Antigravity static config
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil && cfg.Thinking != nil {
|
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||||
return &ModelInfo{
|
return &ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Name: cfg.Name,
|
|
||||||
Thinking: cfg.Thinking,
|
Thinking: cfg.Thinking,
|
||||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,8 @@ type ThinkingSupport struct {
|
|||||||
type ModelRegistration struct {
|
type ModelRegistration struct {
|
||||||
// Info contains the model metadata
|
// Info contains the model metadata
|
||||||
Info *ModelInfo
|
Info *ModelInfo
|
||||||
|
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
|
||||||
|
InfoByProvider map[string]*ModelInfo
|
||||||
// Count is the number of active clients that can provide this model
|
// Count is the number of active clients that can provide this model
|
||||||
Count int
|
Count int
|
||||||
// LastUpdated tracks when this registration was last modified
|
// LastUpdated tracks when this registration was last modified
|
||||||
@@ -134,16 +136,19 @@ func GetGlobalRegistry() *ModelRegistry {
|
|||||||
return globalRegistry
|
return globalRegistry
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupModelInfo searches the dynamic registry first, then falls back to static model definitions.
|
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||||
//
|
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||||
// This helper exists because some code paths only have a model ID and still need Thinking and
|
|
||||||
// max completion token metadata even when the dynamic registry hasn't been populated.
|
|
||||||
func LookupModelInfo(modelID string) *ModelInfo {
|
|
||||||
modelID = strings.TrimSpace(modelID)
|
modelID = strings.TrimSpace(modelID)
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if info := GetGlobalRegistry().GetModelInfo(modelID); info != nil {
|
|
||||||
|
p := ""
|
||||||
|
if len(provider) > 0 {
|
||||||
|
p = strings.ToLower(strings.TrimSpace(provider[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
return LookupStaticModelInfo(modelID)
|
return LookupStaticModelInfo(modelID)
|
||||||
@@ -299,6 +304,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
if count, okProv := reg.Providers[oldProvider]; okProv {
|
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||||
if count <= toRemove {
|
if count <= toRemove {
|
||||||
delete(reg.Providers, oldProvider)
|
delete(reg.Providers, oldProvider)
|
||||||
|
if reg.InfoByProvider != nil {
|
||||||
|
delete(reg.InfoByProvider, oldProvider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
reg.Providers[oldProvider] = count - toRemove
|
reg.Providers[oldProvider] = count - toRemove
|
||||||
}
|
}
|
||||||
@@ -348,6 +356,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
model := newModels[id]
|
model := newModels[id]
|
||||||
if reg, ok := r.models[id]; ok {
|
if reg, ok := r.models[id]; ok {
|
||||||
reg.Info = cloneModelInfo(model)
|
reg.Info = cloneModelInfo(model)
|
||||||
|
if provider != "" {
|
||||||
|
if reg.InfoByProvider == nil {
|
||||||
|
reg.InfoByProvider = make(map[string]*ModelInfo)
|
||||||
|
}
|
||||||
|
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
|
}
|
||||||
reg.LastUpdated = now
|
reg.LastUpdated = now
|
||||||
if reg.QuotaExceededClients != nil {
|
if reg.QuotaExceededClients != nil {
|
||||||
delete(reg.QuotaExceededClients, clientID)
|
delete(reg.QuotaExceededClients, clientID)
|
||||||
@@ -411,11 +425,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
if existing.SuspendedClients == nil {
|
if existing.SuspendedClients == nil {
|
||||||
existing.SuspendedClients = make(map[string]string)
|
existing.SuspendedClients = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
if existing.InfoByProvider == nil {
|
||||||
|
existing.InfoByProvider = make(map[string]*ModelInfo)
|
||||||
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
if existing.Providers == nil {
|
if existing.Providers == nil {
|
||||||
existing.Providers = make(map[string]int)
|
existing.Providers = make(map[string]int)
|
||||||
}
|
}
|
||||||
existing.Providers[provider]++
|
existing.Providers[provider]++
|
||||||
|
existing.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||||
return
|
return
|
||||||
@@ -423,6 +441,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
|
|
||||||
registration := &ModelRegistration{
|
registration := &ModelRegistration{
|
||||||
Info: cloneModelInfo(model),
|
Info: cloneModelInfo(model),
|
||||||
|
InfoByProvider: make(map[string]*ModelInfo),
|
||||||
Count: 1,
|
Count: 1,
|
||||||
LastUpdated: now,
|
LastUpdated: now,
|
||||||
QuotaExceededClients: make(map[string]*time.Time),
|
QuotaExceededClients: make(map[string]*time.Time),
|
||||||
@@ -430,6 +449,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
}
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
registration.Providers = map[string]int{provider: 1}
|
registration.Providers = map[string]int{provider: 1}
|
||||||
|
registration.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
r.models[modelID] = registration
|
r.models[modelID] = registration
|
||||||
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||||
@@ -455,6 +475,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
|
|||||||
if count, ok := registration.Providers[provider]; ok {
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
delete(registration.Providers, provider)
|
delete(registration.Providers, provider)
|
||||||
|
if registration.InfoByProvider != nil {
|
||||||
|
delete(registration.InfoByProvider, provider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
registration.Providers[provider] = count - 1
|
registration.Providers[provider] = count - 1
|
||||||
}
|
}
|
||||||
@@ -539,6 +562,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
if count, ok := registration.Providers[provider]; ok {
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
delete(registration.Providers, provider)
|
delete(registration.Providers, provider)
|
||||||
|
if registration.InfoByProvider != nil {
|
||||||
|
delete(registration.InfoByProvider, provider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
registration.Providers[provider] = count - 1
|
registration.Providers[provider] = count - 1
|
||||||
}
|
}
|
||||||
@@ -945,12 +971,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
|
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
|
||||||
// Returns nil if the model is unknown to the registry.
|
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||||
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
|
|
||||||
r.mutex.RLock()
|
r.mutex.RLock()
|
||||||
defer r.mutex.RUnlock()
|
defer r.mutex.RUnlock()
|
||||||
if reg, ok := r.models[modelID]; ok && reg != nil {
|
if reg, ok := r.models[modelID]; ok && reg != nil {
|
||||||
|
// Try provider specific definition first
|
||||||
|
if provider != "" && reg.InfoByProvider != nil {
|
||||||
|
if reg.Providers != nil {
|
||||||
|
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||||
|
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback to global info (last registered)
|
||||||
return reg.Info
|
return reg.Info
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -1006,10 +1042,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
"owned_by": model.OwnedBy,
|
"owned_by": model.OwnedBy,
|
||||||
}
|
}
|
||||||
if model.Created > 0 {
|
if model.Created > 0 {
|
||||||
result["created"] = model.Created
|
result["created_at"] = model.Created
|
||||||
}
|
}
|
||||||
if model.Type != "" {
|
if model.Type != "" {
|
||||||
result["type"] = model.Type
|
result["type"] = "model"
|
||||||
}
|
}
|
||||||
if model.DisplayName != "" {
|
if model.DisplayName != "" {
|
||||||
result["display_name"] = model.DisplayName
|
result["display_name"] = model.DisplayName
|
||||||
|
|||||||
@@ -393,12 +393,13 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
|||||||
}
|
}
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, translatedPayload{}, err
|
return nil, translatedPayload{}, err
|
||||||
}
|
}
|
||||||
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||||
|
|||||||
@@ -137,97 +137,119 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var lastStatus int
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
var lastBody []byte
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
attemptLoop:
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
if errReq != nil {
|
var lastStatus int
|
||||||
err = errReq
|
var lastBody []byte
|
||||||
return resp, err
|
var lastErr error
|
||||||
}
|
|
||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
for idx, baseURL := range baseURLs {
|
||||||
if errDo != nil {
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
if errReq != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
err = errReq
|
||||||
return resp, errDo
|
return resp, err
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
|
||||||
lastBody = nil
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
lastErr = errDo
|
if errDo != nil {
|
||||||
if idx+1 < len(baseURLs) {
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
continue
|
return resp, errDo
|
||||||
|
}
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errDo
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errDo
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
err = errDo
|
|
||||||
return resp, err
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
|
||||||
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attempt+1 < attempts {
|
||||||
|
delay := antigravityNoCapacityRetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return resp, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||||
|
sErr.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = sErr
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
switch {
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
case lastStatus != 0:
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
if lastStatus == http.StatusTooManyRequests {
|
||||||
}
|
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
err = errRead
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
|
||||||
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
|
||||||
lastStatus = httpResp.StatusCode
|
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
|
||||||
lastErr = nil
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
sErr.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sErr
|
err = sErr
|
||||||
return resp, err
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
}
|
}
|
||||||
|
return resp, err
|
||||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
|
||||||
var param any
|
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
|
||||||
case lastStatus != 0:
|
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
|
||||||
err = lastErr
|
|
||||||
default:
|
|
||||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
|
||||||
}
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,160 +278,182 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var lastStatus int
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
var lastBody []byte
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
attemptLoop:
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
if errReq != nil {
|
var lastStatus int
|
||||||
err = errReq
|
var lastBody []byte
|
||||||
return resp, err
|
var lastErr error
|
||||||
}
|
|
||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
for idx, baseURL := range baseURLs {
|
||||||
if errDo != nil {
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
if errReq != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
err = errReq
|
||||||
return resp, errDo
|
return resp, err
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
|
||||||
lastBody = nil
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
lastErr = errDo
|
if errDo != nil {
|
||||||
if idx+1 < len(baseURLs) {
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
continue
|
return resp, errDo
|
||||||
}
|
|
||||||
err = errDo
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
|
||||||
err = errRead
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
if errCtx := ctx.Err(); errCtx != nil {
|
|
||||||
err = errCtx
|
|
||||||
return resp, err
|
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
lastStatus = 0
|
||||||
lastBody = nil
|
lastBody = nil
|
||||||
lastErr = errRead
|
lastErr = errDo
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = errRead
|
err = errDo
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
lastStatus = httpResp.StatusCode
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
lastErr = nil
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
}
|
||||||
continue
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
err = errCtx
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errRead
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attempt+1 < attempts {
|
||||||
|
delay := antigravityNoCapacityRetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return resp, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||||
|
sErr.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = sErr
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
go func(resp *http.Response) {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
|
// Filter usage metadata for all models
|
||||||
|
// Only retain usage statistics in the terminal chunk
|
||||||
|
line = FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
|
payload := jsonPayload(line)
|
||||||
|
if payload == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
}
|
||||||
|
}(httpResp)
|
||||||
|
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
for chunk := range out {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
return resp, chunk.Err
|
||||||
|
}
|
||||||
|
if len(chunk.Payload) > 0 {
|
||||||
|
_, _ = buffer.Write(chunk.Payload)
|
||||||
|
_, _ = buffer.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||||
|
|
||||||
|
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case lastStatus != 0:
|
||||||
|
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
|
if lastStatus == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||||
sErr.retryAfter = retryAfter
|
sErr.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sErr
|
err = sErr
|
||||||
return resp, err
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
}
|
}
|
||||||
|
return resp, err
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
|
||||||
go func(resp *http.Response) {
|
|
||||||
defer close(out)
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Buffer(nil, streamScannerBuffer)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
||||||
|
|
||||||
// Filter usage metadata for all models
|
|
||||||
// Only retain usage statistics in the terminal chunk
|
|
||||||
line = FilterSSEUsageMetadata(line)
|
|
||||||
|
|
||||||
payload := jsonPayload(line)
|
|
||||||
if payload == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
|
||||||
reporter.publish(ctx, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
|
||||||
}
|
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
||||||
reporter.publishFailure(ctx)
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
||||||
} else {
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
}
|
|
||||||
}(httpResp)
|
|
||||||
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
for chunk := range out {
|
|
||||||
if chunk.Err != nil {
|
|
||||||
return resp, chunk.Err
|
|
||||||
}
|
|
||||||
if len(chunk.Payload) > 0 {
|
|
||||||
_, _ = buffer.Write(chunk.Payload)
|
|
||||||
_, _ = buffer.Write([]byte("\n"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
|
||||||
|
|
||||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
|
||||||
var param any
|
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
|
||||||
case lastStatus != 0:
|
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
|
||||||
err = lastErr
|
|
||||||
default:
|
|
||||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
|
||||||
}
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -622,149 +666,171 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var lastStatus int
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
var lastBody []byte
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
attemptLoop:
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
if errReq != nil {
|
var lastStatus int
|
||||||
err = errReq
|
var lastBody []byte
|
||||||
return nil, err
|
var lastErr error
|
||||||
}
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
for idx, baseURL := range baseURLs {
|
||||||
if errDo != nil {
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
if errReq != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
err = errReq
|
||||||
return nil, errDo
|
return nil, err
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
lastBody = nil
|
if errDo != nil {
|
||||||
lastErr = errDo
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
if idx+1 < len(baseURLs) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
return nil, errDo
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = errDo
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
|
||||||
err = errRead
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if errCtx := ctx.Err(); errCtx != nil {
|
|
||||||
err = errCtx
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
lastStatus = 0
|
||||||
lastBody = nil
|
lastBody = nil
|
||||||
lastErr = errRead
|
lastErr = errDo
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = errRead
|
err = errDo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
lastStatus = httpResp.StatusCode
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
lastErr = nil
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
}
|
||||||
continue
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||||
|
err = errRead
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
err = errCtx
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errRead
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errRead
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attempt+1 < attempts {
|
||||||
|
delay := antigravityNoCapacityRetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return nil, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||||
|
sErr.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = sErr
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
stream = out
|
||||||
|
go func(resp *http.Response) {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
|
var param any
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
|
// Filter usage metadata for all models
|
||||||
|
// Only retain usage statistics in the terminal chunk
|
||||||
|
line = FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
|
payload := jsonPayload(line)
|
||||||
|
if payload == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||||
|
for i := range chunks {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||||
|
for i := range tail {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
}
|
||||||
|
}(httpResp)
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case lastStatus != 0:
|
||||||
|
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
|
if lastStatus == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||||
sErr.retryAfter = retryAfter
|
sErr.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sErr
|
err = sErr
|
||||||
return nil, err
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
}
|
}
|
||||||
|
return nil, err
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
|
||||||
stream = out
|
|
||||||
go func(resp *http.Response) {
|
|
||||||
defer close(out)
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Buffer(nil, streamScannerBuffer)
|
|
||||||
var param any
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
||||||
|
|
||||||
// Filter usage metadata for all models
|
|
||||||
// Only retain usage statistics in the terminal chunk
|
|
||||||
line = FilterSSEUsageMetadata(line)
|
|
||||||
|
|
||||||
payload := jsonPayload(line)
|
|
||||||
if payload == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
|
||||||
reporter.publish(ctx, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
|
||||||
for i := range chunks {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
|
||||||
for i := range tail {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
|
||||||
}
|
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
||||||
reporter.publishFailure(ctx)
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
||||||
} else {
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
}
|
|
||||||
}(httpResp)
|
|
||||||
return stream, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
|
||||||
case lastStatus != 0:
|
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
|
||||||
err = lastErr
|
|
||||||
default:
|
|
||||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -802,7 +868,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
// Prepare payload once (doesn't depend on baseURL)
|
// Prepare payload once (doesn't depend on baseURL)
|
||||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
@@ -994,7 +1060,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
modelConfig := registry.GetAntigravityModelConfig()
|
modelConfig := registry.GetAntigravityModelConfig()
|
||||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||||
for originalName := range result.Map() {
|
for originalName, modelData := range result.Map() {
|
||||||
modelID := strings.TrimSpace(originalName)
|
modelID := strings.TrimSpace(originalName)
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
continue
|
continue
|
||||||
@@ -1004,15 +1070,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
modelCfg := modelConfig[modelID]
|
modelCfg := modelConfig[modelID]
|
||||||
modelName := modelID
|
|
||||||
if modelCfg != nil && modelCfg.Name != "" {
|
// Extract displayName from upstream response, fallback to modelID
|
||||||
modelName = modelCfg.Name
|
displayName := modelData.Get("displayName").String()
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = modelID
|
||||||
}
|
}
|
||||||
|
|
||||||
modelInfo := ®istry.ModelInfo{
|
modelInfo := ®istry.ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Name: modelName,
|
Name: modelID,
|
||||||
Description: modelID,
|
Description: displayName,
|
||||||
DisplayName: modelID,
|
DisplayName: displayName,
|
||||||
Version: modelID,
|
Version: modelID,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: now,
|
Created: now,
|
||||||
@@ -1205,7 +1274,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") {
|
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||||
strJSON := string(payload)
|
strJSON := string(payload)
|
||||||
paths := make([]string, 0)
|
paths := make([]string, 0)
|
||||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||||
@@ -1216,7 +1285,17 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||||
// const->enum conversion, and flattening of types/anyOf.
|
// const->enum conversion, and flattening of types/anyOf.
|
||||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||||
|
payload = []byte(strJSON)
|
||||||
|
} else {
|
||||||
|
strJSON := string(payload)
|
||||||
|
paths := make([]string, 0)
|
||||||
|
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
||||||
|
for _, p := range paths {
|
||||||
|
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||||
|
}
|
||||||
|
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
||||||
|
// without adding empty-schema placeholders.
|
||||||
|
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||||
payload = []byte(strJSON)
|
payload = []byte(strJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1233,6 +1312,12 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.Contains(modelName, "claude") {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||||
|
} else {
|
||||||
|
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||||
|
}
|
||||||
|
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
return nil, errReq
|
return nil, errReq
|
||||||
@@ -1362,14 +1447,70 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
|||||||
return defaultAntigravityAgent
|
return defaultAntigravityAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
|
||||||
|
retry := 0
|
||||||
|
if cfg != nil {
|
||||||
|
retry = cfg.RequestRetry
|
||||||
|
}
|
||||||
|
if auth != nil {
|
||||||
|
if override, ok := auth.RequestRetryOverride(); ok {
|
||||||
|
retry = override
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if retry < 0 {
|
||||||
|
retry = 0
|
||||||
|
}
|
||||||
|
attempts := retry + 1
|
||||||
|
if attempts < 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
|
||||||
|
if statusCode != http.StatusServiceUnavailable {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(string(body))
|
||||||
|
return strings.Contains(msg, "no capacity available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
delay := time.Duration(attempt+1) * 250 * time.Millisecond
|
||||||
|
if delay > 2*time.Second {
|
||||||
|
delay = 2 * time.Second
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityWait(ctx context.Context, wait time.Duration) error {
|
||||||
|
if wait <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(wait)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||||
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
||||||
return []string{base}
|
return []string{base}
|
||||||
}
|
}
|
||||||
return []string{
|
return []string{
|
||||||
antigravitySandboxBaseURLDaily,
|
|
||||||
antigravityBaseURLDaily,
|
antigravityBaseURLDaily,
|
||||||
antigravityBaseURLProd,
|
antigravitySandboxBaseURLDaily,
|
||||||
|
// antigravityBaseURLProd,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1408,31 +1549,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
|||||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||||
|
|
||||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||||
|
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
|
||||||
if !strings.HasPrefix(modelName, "gemini-3-") {
|
template, _ = sjson.Delete(template, "toolConfig")
|
||||||
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
|
||||||
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
|
||||||
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") {
|
|
||||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
|
||||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
|
||||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
|
||||||
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
|
|
||||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
|
|
||||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
|
||||||
}
|
|
||||||
|
|
||||||
return []byte(template)
|
return []byte(template)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -106,22 +105,21 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
|
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||||
body = checkSystemInstructions(body)
|
// based on client type and configuration.
|
||||||
}
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||||
body = disableThinkingIfToolChoiceForced(body)
|
body = disableThinkingIfToolChoiceForced(body)
|
||||||
|
|
||||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
|
||||||
body = ensureMaxTokensForThinking(baseModel, body)
|
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -165,7 +163,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
@@ -239,20 +237,21 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = checkSystemInstructions(body)
|
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
// based on client type and configuration.
|
||||||
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||||
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||||
body = disableThinkingIfToolChoiceForced(body)
|
body = disableThinkingIfToolChoiceForced(body)
|
||||||
|
|
||||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
|
||||||
body = ensureMaxTokensForThinking(baseModel, body)
|
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -296,7 +295,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -541,81 +540,6 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
|
|
||||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
|
||||||
// This function should be called after all thinking configuration is finalized.
|
|
||||||
// It looks up the model's MaxCompletionTokens from the registry to use as the cap.
|
|
||||||
func ensureMaxTokensForThinking(modelName string, body []byte) []byte {
|
|
||||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
|
||||||
if thinkingType != "enabled" {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int()
|
|
||||||
if budgetTokens <= 0 {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
|
|
||||||
|
|
||||||
// Look up the model's max completion tokens from the registry
|
|
||||||
maxCompletionTokens := 0
|
|
||||||
if modelInfo := registry.LookupModelInfo(modelName); modelInfo != nil {
|
|
||||||
maxCompletionTokens = modelInfo.MaxCompletionTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to budget + buffer if registry lookup fails or returns 0
|
|
||||||
const fallbackBuffer = 4000
|
|
||||||
requiredMaxTokens := budgetTokens + fallbackBuffer
|
|
||||||
if maxCompletionTokens > 0 {
|
|
||||||
requiredMaxTokens = int64(maxCompletionTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxTokens < requiredMaxTokens {
|
|
||||||
body, _ = sjson.SetBytes(body, "max_tokens", requiredMaxTokens)
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ClaudeExecutor) resolveClaudeConfig(auth *cliproxyauth.Auth) *config.ClaudeKey {
|
|
||||||
if auth == nil || e.cfg == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var attrKey, attrBase string
|
|
||||||
if auth.Attributes != nil {
|
|
||||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
|
||||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
|
||||||
}
|
|
||||||
for i := range e.cfg.ClaudeKey {
|
|
||||||
entry := &e.cfg.ClaudeKey[i]
|
|
||||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
|
||||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
|
||||||
if attrKey != "" && attrBase != "" {
|
|
||||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
|
||||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if attrKey != "" {
|
|
||||||
for i := range e.cfg.ClaudeKey {
|
|
||||||
entry := &e.cfg.ClaudeKey[i]
|
|
||||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type compositeReadCloser struct {
|
type compositeReadCloser struct {
|
||||||
io.Reader
|
io.Reader
|
||||||
closers []func() error
|
closers []func() error
|
||||||
@@ -809,6 +733,11 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
|||||||
|
|
||||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||||
|
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||||
|
// a "type" field and require their name to remain unchanged.
|
||||||
|
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
name := tool.Get("name").String()
|
name := tool.Get("name").String()
|
||||||
if name == "" || strings.HasPrefix(name, prefix) {
|
if name == "" || strings.HasPrefix(name, prefix) {
|
||||||
return true
|
return true
|
||||||
@@ -901,3 +830,163 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
|||||||
}
|
}
|
||||||
return updated
|
return updated
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getClientUserAgent extracts the client User-Agent from the gin context.
|
||||||
|
func getClientUserAgent(ctx context.Context) string {
|
||||||
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
|
return ginCtx.GetHeader("User-Agent")
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||||
|
// Returns (cloakMode, strictMode, sensitiveWords).
|
||||||
|
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
||||||
|
if auth == nil || auth.Attributes == nil {
|
||||||
|
return "auto", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloakMode := auth.Attributes["cloak_mode"]
|
||||||
|
if cloakMode == "" {
|
||||||
|
cloakMode = "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
strictMode := strings.ToLower(auth.Attributes["cloak_strict_mode"]) == "true"
|
||||||
|
|
||||||
|
var sensitiveWords []string
|
||||||
|
if wordsStr := auth.Attributes["cloak_sensitive_words"]; wordsStr != "" {
|
||||||
|
sensitiveWords = strings.Split(wordsStr, ",")
|
||||||
|
for i := range sensitiveWords {
|
||||||
|
sensitiveWords[i] = strings.TrimSpace(sensitiveWords[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloakMode, strictMode, sensitiveWords
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||||
|
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, baseURL := claudeCreds(auth)
|
||||||
|
if apiKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cfg.ClaudeKey {
|
||||||
|
entry := &cfg.ClaudeKey[i]
|
||||||
|
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||||
|
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||||
|
|
||||||
|
// Match by API key
|
||||||
|
if strings.EqualFold(cfgKey, apiKey) {
|
||||||
|
// If baseURL is specified, also check it
|
||||||
|
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return entry.Cloak
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||||
|
func injectFakeUserID(payload []byte) []byte {
|
||||||
|
metadata := gjson.GetBytes(payload, "metadata")
|
||||||
|
if !metadata.Exists() {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||||
|
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSystemInstructionsWithMode injects Claude Code system prompt.
|
||||||
|
// In strict mode, it replaces all user system messages.
|
||||||
|
// In non-strict mode (default), it prepends to existing system messages.
|
||||||
|
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||||
|
system := gjson.GetBytes(payload, "system")
|
||||||
|
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||||
|
|
||||||
|
if strictMode {
|
||||||
|
// Strict mode: replace all system messages with Claude Code prompt only
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
|
||||||
|
if system.IsArray() {
|
||||||
|
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||||
|
system.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
||||||
|
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||||
|
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
|
||||||
|
clientUserAgent := getClientUserAgent(ctx)
|
||||||
|
|
||||||
|
// Get cloak config from ClaudeKey configuration
|
||||||
|
cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)
|
||||||
|
|
||||||
|
// Determine cloak settings
|
||||||
|
var cloakMode string
|
||||||
|
var strictMode bool
|
||||||
|
var sensitiveWords []string
|
||||||
|
|
||||||
|
if cloakCfg != nil {
|
||||||
|
cloakMode = cloakCfg.Mode
|
||||||
|
strictMode = cloakCfg.StrictMode
|
||||||
|
sensitiveWords = cloakCfg.SensitiveWords
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to auth attributes if no config found
|
||||||
|
if cloakMode == "" {
|
||||||
|
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
|
||||||
|
cloakMode = attrMode
|
||||||
|
if !strictMode {
|
||||||
|
strictMode = attrStrict
|
||||||
|
}
|
||||||
|
if len(sensitiveWords) == 0 {
|
||||||
|
sensitiveWords = attrWords
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine if cloaking should be applied
|
||||||
|
if !shouldCloak(cloakMode, clientUserAgent) {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip system instructions for claude-3-5-haiku models
|
||||||
|
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||||
|
payload = checkSystemInstructionsWithMode(payload, strictMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject fake user ID
|
||||||
|
payload = injectFakeUserID(payload)
|
||||||
|
|
||||||
|
// Apply sensitive word obfuscation
|
||||||
|
if len(sensitiveWords) > 0 {
|
||||||
|
matcher := buildSensitiveWordMatcher(sensitiveWords)
|
||||||
|
payload = obfuscateSensitiveWords(payload, matcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" {
|
||||||
|
t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
|||||||
176
internal/runtime/executor/cloak_obfuscate.go
Normal file
176
internal/runtime/executor/cloak_obfuscate.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// zeroWidthSpace is the Unicode zero-width space character used for obfuscation.
const zeroWidthSpace = "\u200B"

// SensitiveWordMatcher holds the compiled regex for matching sensitive words.
type SensitiveWordMatcher struct {
	regex *regexp.Regexp
}

// buildSensitiveWordMatcher compiles a case-insensitive alternation regex
// from the word list. Words shorter than two runes, or already containing a
// zero-width space, are dropped. Words are sorted longest-first so that
// overlapping alternatives prefer the longer match. Returns nil when no
// usable words remain or compilation fails.
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
	if len(words) == 0 {
		return nil
	}

	// Trim and filter; keep only words worth matching.
	var valid []string
	for _, raw := range words {
		trimmed := strings.TrimSpace(raw)
		if utf8.RuneCountInString(trimmed) < 2 || strings.Contains(trimmed, zeroWidthSpace) {
			continue
		}
		valid = append(valid, trimmed)
	}
	if len(valid) == 0 {
		return nil
	}

	// Longest-first so the regex alternation matches the longer word when
	// one word is a prefix of another.
	sort.Slice(valid, func(i, j int) bool { return len(valid[i]) > len(valid[j]) })

	quoted := make([]string, 0, len(valid))
	for _, w := range valid {
		quoted = append(quoted, regexp.QuoteMeta(w))
	}

	re, err := regexp.Compile("(?i)" + strings.Join(quoted, "|"))
	if err != nil {
		return nil
	}
	return &SensitiveWordMatcher{regex: re}
}

// obfuscateWord inserts a zero-width space after the first rune. Words that
// already contain a zero-width space, single-rune words, and words whose
// first byte is not valid UTF-8 are returned unchanged.
func obfuscateWord(word string) string {
	if strings.Contains(word, zeroWidthSpace) {
		return word
	}

	first, width := utf8.DecodeRuneInString(word)
	if first == utf8.RuneError || width >= len(word) {
		return word
	}

	return word[:width] + zeroWidthSpace + word[width:]
}

// obfuscateText replaces every sensitive word occurrence in text. Nil
// matchers are a no-op so callers do not need to guard.
func (m *SensitiveWordMatcher) obfuscateText(text string) string {
	if m == nil || m.regex == nil {
		return text
	}
	return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
}
|
||||||
|
|
||||||
|
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||||
|
// in system blocks and message content.
|
||||||
|
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
|
if matcher == nil || matcher.regex == nil {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obfuscate in system blocks
|
||||||
|
payload = obfuscateSystemBlocks(payload, matcher)
|
||||||
|
|
||||||
|
// Obfuscate in messages
|
||||||
|
payload = obfuscateMessages(payload, matcher)
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateSystemBlocks obfuscates sensitive words in system blocks.
|
||||||
|
func obfuscateSystemBlocks(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
|
system := gjson.GetBytes(payload, "system")
|
||||||
|
if !system.Exists() {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
if system.IsArray() {
|
||||||
|
modified := false
|
||||||
|
system.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
if value.Get("type").String() == "text" {
|
||||||
|
text := value.Get("text").String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
path := "system." + key.String() + ".text"
|
||||||
|
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if modified {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
} else if system.Type == gjson.String {
|
||||||
|
text := system.String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "system", obfuscated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateMessages obfuscates sensitive words in message content.
|
||||||
|
func obfuscateMessages(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.ForEach(func(msgKey, msg gjson.Result) bool {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msgPath := "messages." + msgKey.String()
|
||||||
|
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
// Simple string content
|
||||||
|
text := content.String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
payload, _ = sjson.SetBytes(payload, msgPath+".content", obfuscated)
|
||||||
|
}
|
||||||
|
} else if content.IsArray() {
|
||||||
|
// Array of content blocks
|
||||||
|
content.ForEach(func(blockKey, block gjson.Result) bool {
|
||||||
|
if block.Get("type").String() == "text" {
|
||||||
|
text := block.Get("text").String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
path := msgPath + ".content." + blockKey.String() + ".text"
|
||||||
|
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
47
internal/runtime/executor/cloak_utils.go
Normal file
47
internal/runtime/executor/cloak_utils.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4].
// The 64-char hex segment accepts either case; the UUID segment is lowercase only.
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||||
|
|
||||||
|
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||||
|
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
||||||
|
func generateFakeUserID() string {
|
||||||
|
hexBytes := make([]byte, 32)
|
||||||
|
_, _ = rand.Read(hexBytes)
|
||||||
|
hexPart := hex.EncodeToString(hexBytes)
|
||||||
|
uuidPart := uuid.New().String()
|
||||||
|
return "user_" + hexPart + "_account__session_" + uuidPart
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidUserID checks if a user ID matches Claude Code format
// (user_[64-hex]_account__session_[uuid-v4]; see userIDPattern).
func isValidUserID(userID string) bool {
	return userIDPattern.MatchString(userID)
}
|
||||||
|
|
||||||
|
// shouldCloak determines if a request should be cloaked based on the
// configured mode and the client's User-Agent. The mode is compared
// case-insensitively: "always" forces cloaking on, "never" forces it off,
// and anything else ("auto" or empty) cloaks every client except Claude
// Code itself.
func shouldCloak(cloakMode string, userAgent string) bool {
	mode := strings.ToLower(cloakMode)
	if mode == "always" {
		return true
	}
	if mode == "never" {
		return false
	}
	// "auto"/empty: if the client is Claude Code, don't cloak.
	return !strings.HasPrefix(userAgent, "claude-cli")
}
|
||||||
|
|
||||||
|
// isClaudeCodeClient checks if the User-Agent indicates a Claude Code client
// (the official CLI sends User-Agent strings starting with "claude-cli").
func isClaudeCodeClient(userAgent string) bool {
	const claudeCLIPrefix = "claude-cli"
	return strings.HasPrefix(userAgent, claudeCLIPrefix)
}
|
||||||
@@ -96,12 +96,13 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||||
body = misc.StripCodexUserAgent(body)
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
@@ -149,7 +150,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -208,12 +209,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
||||||
body = misc.StripCodexUserAgent(body)
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
@@ -263,7 +265,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, readErr
|
return nil, readErr
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -316,7 +318,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||||
body = misc.StripCodexUserAgent(body)
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -123,13 +123,14 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -226,7 +227,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -272,13 +273,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||||
|
|
||||||
projectID := resolveGeminiProjectID(auth)
|
projectID := resolveGeminiProjectID(auth)
|
||||||
|
|
||||||
@@ -358,7 +360,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -479,7 +481,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
for range models {
|
for range models {
|
||||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,13 +120,14 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
@@ -187,7 +188,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -222,13 +223,14 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
baseURL := resolveGeminiBaseURL(auth)
|
baseURL := resolveGeminiBaseURL(auth)
|
||||||
@@ -280,7 +282,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -338,7 +340,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
@@ -400,7 +402,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
log.Debugf("request error, error status: %d, error body: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -31,6 +32,143 @@ const (
|
|||||||
vertexAPIVersion = "v1"
|
vertexAPIVersion = "v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isImagenModel checks if the model name is an Imagen image generation model
// (case-insensitive substring match on "imagen"). Imagen models use the
// :predict action instead of :generateContent.
func isImagenModel(model string) bool {
	return strings.Contains(strings.ToLower(model), "imagen")
}
|
||||||
|
|
||||||
|
// getVertexAction returns the appropriate action for the given model.
|
||||||
|
// Imagen models use "predict", while Gemini models use "generateContent".
|
||||||
|
func getVertexAction(model string, isStream bool) string {
|
||||||
|
if isImagenModel(model) {
|
||||||
|
return "predict"
|
||||||
|
}
|
||||||
|
if isStream {
|
||||||
|
return "streamGenerateContent"
|
||||||
|
}
|
||||||
|
return "generateContent"
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertImagenToGeminiResponse converts Imagen API response to Gemini format
|
||||||
|
// so it can be processed by the standard translation pipeline.
|
||||||
|
// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview.
|
||||||
|
func convertImagenToGeminiResponse(data []byte, model string) []byte {
|
||||||
|
predictions := gjson.GetBytes(data, "predictions")
|
||||||
|
if !predictions.Exists() || !predictions.IsArray() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Gemini-compatible response with inlineData
|
||||||
|
parts := make([]map[string]any, 0)
|
||||||
|
for _, pred := range predictions.Array() {
|
||||||
|
imageData := pred.Get("bytesBase64Encoded").String()
|
||||||
|
mimeType := pred.Get("mimeType").String()
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/png"
|
||||||
|
}
|
||||||
|
if imageData != "" {
|
||||||
|
parts = append(parts, map[string]any{
|
||||||
|
"inlineData": map[string]any{
|
||||||
|
"mimeType": mimeType,
|
||||||
|
"data": imageData,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique response ID using timestamp
|
||||||
|
responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
response := map[string]any{
|
||||||
|
"candidates": []map[string]any{{
|
||||||
|
"content": map[string]any{
|
||||||
|
"parts": parts,
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finishReason": "STOP",
|
||||||
|
}},
|
||||||
|
"responseId": responseId,
|
||||||
|
"modelVersion": model,
|
||||||
|
// Imagen API doesn't return token counts, set to 0 for tracking purposes
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": 0,
|
||||||
|
"candidatesTokenCount": 0,
|
||||||
|
"totalTokenCount": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToImagenRequest converts a Gemini-style request to Imagen API format.
|
||||||
|
// Imagen API uses a different structure: instances[].prompt instead of contents[].
|
||||||
|
func convertToImagenRequest(payload []byte) ([]byte, error) {
|
||||||
|
// Extract prompt from Gemini-style contents
|
||||||
|
prompt := ""
|
||||||
|
|
||||||
|
// Try to get prompt from contents[0].parts[0].text
|
||||||
|
contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text")
|
||||||
|
if contentsText.Exists() {
|
||||||
|
prompt = contentsText.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no contents, try messages format (OpenAI-compatible)
|
||||||
|
if prompt == "" {
|
||||||
|
messagesText := gjson.GetBytes(payload, "messages.#.content")
|
||||||
|
if messagesText.Exists() && messagesText.IsArray() {
|
||||||
|
for _, msg := range messagesText.Array() {
|
||||||
|
if msg.String() != "" {
|
||||||
|
prompt = msg.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If still no prompt, try direct prompt field
|
||||||
|
if prompt == "" {
|
||||||
|
directPrompt := gjson.GetBytes(payload, "prompt")
|
||||||
|
if directPrompt.Exists() {
|
||||||
|
prompt = directPrompt.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt == "" {
|
||||||
|
return nil, fmt.Errorf("imagen: no prompt found in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Imagen API request
|
||||||
|
imagenReq := map[string]any{
|
||||||
|
"instances": []map[string]any{
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"sampleCount": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract optional parameters
|
||||||
|
if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String()
|
||||||
|
}
|
||||||
|
if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int())
|
||||||
|
}
|
||||||
|
if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
|
||||||
|
imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(imagenReq)
|
||||||
|
}
|
||||||
|
|
||||||
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
||||||
type GeminiVertexExecutor struct {
|
type GeminiVertexExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -160,26 +298,39 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
var body []byte
|
||||||
to := sdktranslator.FromString("gemini")
|
|
||||||
|
|
||||||
originalPayload := bytes.Clone(req.Payload)
|
// Handle Imagen models with special request format
|
||||||
if len(opts.OriginalRequest) > 0 {
|
if isImagenModel(baseModel) {
|
||||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
imagenBody, errImagen := convertToImagenRequest(req.Payload)
|
||||||
}
|
if errImagen != nil {
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
return resp, errImagen
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
}
|
||||||
|
body = imagenBody
|
||||||
|
} else {
|
||||||
|
// Standard Gemini translation flow
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
if err != nil {
|
if len(opts.OriginalRequest) > 0 {
|
||||||
return resp, err
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
action := getVertexAction(baseModel, false)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
|
||||||
|
|
||||||
action := "generateContent"
|
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -238,7 +389,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -249,6 +400,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
|
|
||||||
|
// For Imagen models, convert response to Gemini format before translation
|
||||||
|
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||||
|
if isImagenModel(baseModel) {
|
||||||
|
data = convertImagenToGeminiResponse(data, baseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
@@ -272,16 +433,17 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := getVertexAction(baseModel, false)
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -341,7 +503,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -375,21 +537,26 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -434,7 +601,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -494,24 +661,29 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -553,7 +725,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -605,7 +777,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
|
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
@@ -666,7 +838,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -689,7 +861,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
|
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
@@ -750,7 +922,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
|||||||
@@ -119,7 +119,8 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", false)
|
body, _ = sjson.SetBytes(body, "stream", false)
|
||||||
|
|
||||||
path := githubCopilotChatPath
|
path := githubCopilotChatPath
|
||||||
@@ -218,7 +219,8 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
// Enable stream options for usage stats in stream
|
// Enable stream options for usage stats in stream
|
||||||
if !useResponses {
|
if !useResponses {
|
||||||
|
|||||||
@@ -92,13 +92,14 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = preserveReasoningContentInMessages(body)
|
body = preserveReasoningContentInMessages(body)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
|
|
||||||
@@ -141,7 +142,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("iflow request error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -190,7 +191,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -201,7 +202,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||||
body = ensureToolsArray(body)
|
body = ensureToolsArray(body)
|
||||||
}
|
}
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
|
|
||||||
@@ -242,7 +244,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
log.Debugf("iflow streaming error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -332,6 +335,12 @@ func summarizeErrorBody(contentType string, body []byte) string {
|
|||||||
}
|
}
|
||||||
return "[html body omitted]"
|
return "[html body omitted]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to extract error message from JSON response
|
||||||
|
if message := extractJSONErrorMessage(body); message != "" {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
return string(body)
|
return string(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -358,3 +367,25 @@ func extractHTMLTitle(body []byte) string {
|
|||||||
}
|
}
|
||||||
return strings.Join(strings.Fields(title), " ")
|
return strings.Join(strings.Fields(title), " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractJSONErrorMessage attempts to extract error.message from JSON error responses
|
||||||
|
func extractJSONErrorMessage(body []byte) string {
|
||||||
|
result := gjson.GetBytes(body, "error.message")
|
||||||
|
if result.Exists() && result.String() != "" {
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||||
|
// If no request ID is found in context, it returns the standard logger.
|
||||||
|
func logWithRequestID(ctx context.Context) *log.Entry {
|
||||||
|
if ctx == nil {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
requestID := logging.GetRequestID(ctx)
|
||||||
|
if requestID == "" {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
return log.WithField("request_id", requestID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -90,9 +90,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -145,7 +146,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -185,9 +186,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
}
|
}
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -237,7 +239,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -297,7 +299,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
|
|
||||||
modelForCounting := baseModel
|
modelForCounting := baseModel
|
||||||
|
|
||||||
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -12,8 +14,9 @@ import (
|
|||||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||||
// against the original payload when provided.
|
// against the original payload when provided. requestedModel carries the client-visible
|
||||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
|
// model name before alias resolution so payload rules can target aliases precisely.
|
||||||
|
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||||
if cfg == nil || len(payload) == 0 {
|
if cfg == nil || len(payload) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
@@ -22,10 +25,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
model = strings.TrimSpace(model)
|
model = strings.TrimSpace(model)
|
||||||
if model == "" {
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
|
if model == "" && requestedModel == "" {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
candidates := payloadModelCandidates(cfg, model, protocol)
|
candidates := payloadModelCandidates(model, requestedModel)
|
||||||
out := payload
|
out := payload
|
||||||
source := original
|
source := original
|
||||||
if len(source) == 0 {
|
if len(source) == 0 {
|
||||||
@@ -163,65 +167,42 @@ func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func payloadModelCandidates(cfg *config.Config, model, protocol string) []string {
|
func payloadModelCandidates(model, requestedModel string) []string {
|
||||||
model = strings.TrimSpace(model)
|
model = strings.TrimSpace(model)
|
||||||
if model == "" {
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
|
if model == "" && requestedModel == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
candidates := []string{model}
|
candidates := make([]string, 0, 3)
|
||||||
if cfg == nil {
|
seen := make(map[string]struct{}, 3)
|
||||||
return candidates
|
addCandidate := func(value string) {
|
||||||
}
|
value = strings.TrimSpace(value)
|
||||||
aliases := payloadModelAliases(cfg, model, protocol)
|
if value == "" {
|
||||||
if len(aliases) == 0 {
|
return
|
||||||
return candidates
|
|
||||||
}
|
|
||||||
seen := map[string]struct{}{strings.ToLower(model): struct{}{}}
|
|
||||||
for _, alias := range aliases {
|
|
||||||
alias = strings.TrimSpace(alias)
|
|
||||||
if alias == "" {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
key := strings.ToLower(alias)
|
key := strings.ToLower(value)
|
||||||
if _, ok := seen[key]; ok {
|
if _, ok := seen[key]; ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
seen[key] = struct{}{}
|
seen[key] = struct{}{}
|
||||||
candidates = append(candidates, alias)
|
candidates = append(candidates, value)
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
addCandidate(model)
|
||||||
|
}
|
||||||
|
if requestedModel != "" {
|
||||||
|
parsed := thinking.ParseSuffix(requestedModel)
|
||||||
|
base := strings.TrimSpace(parsed.ModelName)
|
||||||
|
if base != "" {
|
||||||
|
addCandidate(base)
|
||||||
|
}
|
||||||
|
if parsed.HasSuffix {
|
||||||
|
addCandidate(requestedModel)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return candidates
|
return candidates
|
||||||
}
|
}
|
||||||
|
|
||||||
func payloadModelAliases(cfg *config.Config, model, protocol string) []string {
|
|
||||||
if cfg == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
model = strings.TrimSpace(model)
|
|
||||||
if model == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
channel := strings.ToLower(strings.TrimSpace(protocol))
|
|
||||||
if channel == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
entries := cfg.OAuthModelAlias[channel]
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
aliases := make([]string, 0, 2)
|
|
||||||
for _, entry := range entries {
|
|
||||||
if !strings.EqualFold(strings.TrimSpace(entry.Name), model) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
alias := strings.TrimSpace(entry.Alias)
|
|
||||||
if alias == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
aliases = append(aliases, alias)
|
|
||||||
}
|
|
||||||
return aliases
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPayloadPath combines an optional root path with a relative parameter path.
|
// buildPayloadPath combines an optional root path with a relative parameter path.
|
||||||
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
||||||
// the parameter path is treated as relative to root.
|
// the parameter path is treated as relative to root.
|
||||||
@@ -258,6 +239,35 @@ func payloadRawValue(value any) ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||||
|
fallback = strings.TrimSpace(fallback)
|
||||||
|
if len(opts.Metadata) == 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(v) == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
if len(v) == 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(string(v))
|
||||||
|
if trimmed == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return trimmed
|
||||||
|
default:
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
||||||
// Examples:
|
// Examples:
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -86,12 +86,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -132,7 +133,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -172,7 +173,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -184,7 +185,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||||
}
|
}
|
||||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -220,7 +222,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
|||||||
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
||||||
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
||||||
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
||||||
|
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - Modified request body JSON with thinking configuration applied
|
// - Modified request body JSON with thinking configuration applied
|
||||||
@@ -79,12 +80,16 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
|||||||
// Example:
|
// Example:
|
||||||
//
|
//
|
||||||
// // With suffix - suffix config takes priority
|
// // With suffix - suffix config takes priority
|
||||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini")
|
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
|
||||||
//
|
//
|
||||||
// // Without suffix - uses body config
|
// // Without suffix - uses body config
|
||||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini")
|
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
|
||||||
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string) ([]byte, error) {
|
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
|
||||||
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
||||||
|
providerKey = strings.ToLower(strings.TrimSpace(providerKey))
|
||||||
|
if providerKey == "" {
|
||||||
|
providerKey = providerFormat
|
||||||
|
}
|
||||||
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
||||||
if fromFormat == "" {
|
if fromFormat == "" {
|
||||||
fromFormat = providerFormat
|
fromFormat = providerFormat
|
||||||
@@ -102,7 +107,8 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string
|
|||||||
// 2. Parse suffix and get modelInfo
|
// 2. Parse suffix and get modelInfo
|
||||||
suffixResult := ParseSuffix(model)
|
suffixResult := ParseSuffix(model)
|
||||||
baseModel := suffixResult.ModelName
|
baseModel := suffixResult.ModelName
|
||||||
modelInfo := registry.LookupModelInfo(baseModel)
|
// Use provider-specific lookup to handle capability differences across providers.
|
||||||
|
modelInfo := registry.LookupModelInfo(baseModel, providerKey)
|
||||||
|
|
||||||
// 3. Model capability check
|
// 3. Model capability check
|
||||||
// Unknown models are treated as user-defined so thinking config can still be applied.
|
// Unknown models are treated as user-defined so thinking config can still be applied.
|
||||||
|
|||||||
@@ -80,9 +80,66 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||||
|
|
||||||
|
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
||||||
|
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
||||||
|
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||||
|
func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo *registry.ModelInfo) []byte {
|
||||||
|
if budgetTokens <= 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the request satisfies Claude constraints:
|
||||||
|
// 1) Determine effective max_tokens (request overrides model default)
|
||||||
|
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
|
||||||
|
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
|
||||||
|
// 4) If max_tokens came from model default, write it back into the request
|
||||||
|
|
||||||
|
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
|
||||||
|
if setDefaultMax && effectiveMax > 0 {
|
||||||
|
body, _ = sjson.SetBytes(body, "max_tokens", effectiveMax)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the budget we would apply after enforcing budget_tokens < max_tokens.
|
||||||
|
adjustedBudget := budgetTokens
|
||||||
|
if effectiveMax > 0 && adjustedBudget >= effectiveMax {
|
||||||
|
adjustedBudget = effectiveMax - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
minBudget := 0
|
||||||
|
if modelInfo != nil && modelInfo.Thinking != nil {
|
||||||
|
minBudget = modelInfo.Thinking.Min
|
||||||
|
}
|
||||||
|
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
|
||||||
|
// If enforcing the max_tokens constraint would push the budget below the model minimum,
|
||||||
|
// leave the request unchanged.
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if adjustedBudget != budgetTokens {
|
||||||
|
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", adjustedBudget)
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// effectiveMaxTokens returns the max tokens to cap thinking:
|
||||||
|
// prefer request-provided max_tokens; otherwise fall back to model default.
|
||||||
|
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||||
|
func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
|
||||||
|
if maxTok := gjson.GetBytes(body, "max_tokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||||
|
return int(maxTok.Int()), false
|
||||||
|
}
|
||||||
|
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||||
|
return modelInfo.MaxCompletionTokens, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
return body, nil
|
return body, nil
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ package claude
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
@@ -19,29 +17,6 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// deriveSessionID generates a stable session ID from the request.
|
|
||||||
// Uses the hash of the first user message to identify the conversation.
|
|
||||||
func deriveSessionID(rawJSON []byte) string {
|
|
||||||
messages := gjson.GetBytes(rawJSON, "messages")
|
|
||||||
if !messages.IsArray() {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
for _, msg := range messages.Array() {
|
|
||||||
if msg.Get("role").String() == "user" {
|
|
||||||
content := msg.Get("content").String()
|
|
||||||
if content == "" {
|
|
||||||
// Try to get text from content array
|
|
||||||
content = msg.Get("content.0.text").String()
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
h := sha256.Sum256([]byte(content))
|
|
||||||
return hex.EncodeToString(h[:16])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
||||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||||
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
|
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
|
||||||
@@ -61,11 +36,9 @@ func deriveSessionID(rawJSON []byte) string {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []byte: The transformed request data in Gemini CLI API format
|
// - []byte: The transformed request data in Gemini CLI API format
|
||||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
|
enableThoughtTranslate := true
|
||||||
rawJSON := bytes.Clone(inputRawJSON)
|
rawJSON := bytes.Clone(inputRawJSON)
|
||||||
|
|
||||||
// Derive session ID for signature caching
|
|
||||||
sessionID := deriveSessionID(rawJSON)
|
|
||||||
|
|
||||||
// system instruction
|
// system instruction
|
||||||
systemInstructionJSON := ""
|
systemInstructionJSON := ""
|
||||||
hasSystemInstruction := false
|
hasSystemInstruction := false
|
||||||
@@ -124,41 +97,49 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||||
// Use GetThinkingText to handle wrapped thinking objects
|
// Use GetThinkingText to handle wrapped thinking objects
|
||||||
thinkingText := thinking.GetThinkingText(contentResult)
|
thinkingText := thinking.GetThinkingText(contentResult)
|
||||||
signatureResult := contentResult.Get("signature")
|
|
||||||
clientSignature := ""
|
|
||||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
|
||||||
clientSignature = signatureResult.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Always try cached signature first (more reliable than client-provided)
|
// Always try cached signature first (more reliable than client-provided)
|
||||||
// Client may send stale or invalid signatures from different sessions
|
// Client may send stale or invalid signatures from different sessions
|
||||||
signature := ""
|
signature := ""
|
||||||
if sessionID != "" && thinkingText != "" {
|
if thinkingText != "" {
|
||||||
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
|
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
|
||||||
signature = cachedSig
|
signature = cachedSig
|
||||||
// log.Debugf("Using cached signature for thinking block")
|
// log.Debugf("Using cached signature for thinking block")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to client signature only if cache miss and client signature is valid
|
// Fallback to client signature only if cache miss and client signature is valid
|
||||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
if signature == "" {
|
||||||
signature = clientSignature
|
signatureResult := contentResult.Get("signature")
|
||||||
|
clientSignature := ""
|
||||||
|
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||||
|
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||||
|
if len(arrayClientSignatures) == 2 {
|
||||||
|
if modelName == arrayClientSignatures[0] {
|
||||||
|
clientSignature = arrayClientSignatures[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cache.HasValidSignature(modelName, clientSignature) {
|
||||||
|
signature = clientSignature
|
||||||
|
}
|
||||||
// log.Debugf("Using client-provided signature for thinking block")
|
// log.Debugf("Using client-provided signature for thinking block")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store for subsequent tool_use in the same message
|
// Store for subsequent tool_use in the same message
|
||||||
if cache.HasValidSignature(signature) {
|
if cache.HasValidSignature(modelName, signature) {
|
||||||
currentMessageThinkingSignature = signature
|
currentMessageThinkingSignature = signature
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip trailing unsigned thinking blocks on last assistant message
|
// Skip trailing unsigned thinking blocks on last assistant message
|
||||||
isUnsigned := !cache.HasValidSignature(signature)
|
isUnsigned := !cache.HasValidSignature(modelName, signature)
|
||||||
|
|
||||||
// If unsigned, skip entirely (don't convert to text)
|
// If unsigned, skip entirely (don't convert to text)
|
||||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||||
// Converting to text would break this requirement
|
// Converting to text would break this requirement
|
||||||
if isUnsigned {
|
if isUnsigned {
|
||||||
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||||
|
enableThoughtTranslate = false
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,7 +187,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||||
// and also works for Claude through Antigravity API
|
// and also works for Claude through Antigravity API
|
||||||
const skipSentinel = "skip_thought_signature_validator"
|
const skipSentinel = "skip_thought_signature_validator"
|
||||||
if cache.HasValidSignature(currentMessageThinkingSignature) {
|
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||||
} else {
|
} else {
|
||||||
// No valid signature - use skip sentinel to bypass validation
|
// No valid signature - use skip sentinel to bypass validation
|
||||||
@@ -386,7 +367,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||||
if t.Get("type").String() == "enabled" {
|
if t.Get("type").String() == "enabled" {
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,30 +74,41 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Valid signature must be at least 50 characters
|
// Valid signature must be at least 50 characters
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think..."
|
||||||
|
|
||||||
|
// Pre-cache the signature (simulating a previous response for the same thinking text)
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
{"type": "text", "text": "Answer"}
|
{"type": "text", "text": "Answer"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check thinking block conversion
|
// Check thinking block conversion (now in contents.1 due to user message)
|
||||||
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
|
firstPart := gjson.Get(outputStr, "request.contents.1.parts.0")
|
||||||
if !firstPart.Get("thought").Bool() {
|
if !firstPart.Get("thought").Bool() {
|
||||||
t.Error("thinking block should have thought: true")
|
t.Error("thinking block should have thought: true")
|
||||||
}
|
}
|
||||||
if firstPart.Get("text").String() != "Let me think..." {
|
if firstPart.Get("text").String() != thinkingText {
|
||||||
t.Error("thinking text mismatch")
|
t.Error("thinking text mismatch")
|
||||||
}
|
}
|
||||||
if firstPart.Get("thoughtSignature").String() != validSignature {
|
if firstPart.Get("thoughtSignature").String() != validSignature {
|
||||||
@@ -105,6 +117,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
@@ -226,14 +240,22 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
{
|
{
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
"id": "call_123",
|
"id": "call_123",
|
||||||
@@ -245,11 +267,13 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check function call has the signature from the preceding thinking block
|
// Check function call has the signature from the preceding thinking block (now in contents.1)
|
||||||
part := gjson.Get(outputStr, "request.contents.0.parts.1")
|
part := gjson.Get(outputStr, "request.contents.1.parts.1")
|
||||||
if part.Get("functionCall.name").String() != "get_weather" {
|
if part.Get("functionCall.name").String() != "get_weather" {
|
||||||
t.Errorf("Expected functionCall, got %s", part.Raw)
|
t.Errorf("Expected functionCall, got %s", part.Raw)
|
||||||
}
|
}
|
||||||
@@ -259,26 +283,36 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Planning..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "Here is the plan."},
|
{"type": "text", "text": "Here is the plan."},
|
||||||
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Verify order: Thinking block MUST be first
|
// Verify order: Thinking block MUST be first (now in contents.1 due to user message)
|
||||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||||
}
|
}
|
||||||
@@ -459,7 +493,12 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Last assistant message ends with signed thinking block - should be kept
|
// Last assistant message ends with signed thinking block - should be kept
|
||||||
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Valid thinking..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -471,12 +510,14 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
|
|||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "Here is my answer"},
|
{"type": "text", "text": "Here is my answer"},
|
||||||
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ type Params struct {
|
|||||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||||
|
|
||||||
// Signature caching support
|
// Signature caching support
|
||||||
SessionID string // Session ID derived from request for signature caching
|
|
||||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,9 +69,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
HasFirstResponse: false,
|
HasFirstResponse: false,
|
||||||
ResponseType: 0,
|
ResponseType: 0,
|
||||||
ResponseIndex: 0,
|
ResponseIndex: 0,
|
||||||
SessionID: deriveSessionID(originalRequestRawJSON),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||||
|
|
||||||
params := (*param).(*Params)
|
params := (*param).(*Params)
|
||||||
|
|
||||||
@@ -138,14 +137,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||||
// log.Debug("Branch: signature_delta")
|
// log.Debug("Branch: signature_delta")
|
||||||
|
|
||||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
if params.CurrentThinkingText.Len() > 0 {
|
||||||
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||||
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
|
||||||
params.CurrentThinkingText.Reset()
|
params.CurrentThinkingText.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
output = output + "event: content_block_delta\n"
|
output = output + "event: content_block_delta\n"
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||||
@@ -372,7 +371,7 @@ func resolveStopReason(params *Params) string {
|
|||||||
// - string: A Claude-compatible JSON response.
|
// - string: A Claude-compatible JSON response.
|
||||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||||
_ = originalRequestRawJSON
|
_ = originalRequestRawJSON
|
||||||
_ = requestRawJSON
|
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||||
|
|
||||||
root := gjson.ParseBytes(rawJSON)
|
root := gjson.ParseBytes(rawJSON)
|
||||||
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
||||||
@@ -437,7 +436,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
block := `{"type":"thinking","thinking":""}`
|
block := `{"type":"thinking","thinking":""}`
|
||||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||||
if thinkingSignature != "" {
|
if thinkingSignature != "" {
|
||||||
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||||
}
|
}
|
||||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||||
thinkingBuilder.Reset()
|
thinkingBuilder.Reset()
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
// Signature Caching Tests
|
// Signature Caching Tests
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
|
||||||
cache.ClearSignatureCache("")
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Request with user message - should derive session ID
|
// Request with user message - should initialize params
|
||||||
requestJSON := []byte(`{
|
requestJSON := []byte(`{
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
|
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
|
||||||
@@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
|
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
|
||||||
|
|
||||||
// Verify session ID was set
|
|
||||||
params := param.(*Params)
|
params := param.(*Params)
|
||||||
if params.SessionID == "" {
|
if !params.HasFirstResponse {
|
||||||
t.Error("SessionID should be derived from request")
|
t.Error("HasFirstResponse should be set after first chunk")
|
||||||
|
}
|
||||||
|
if params.CurrentThinkingText.Len() == 0 {
|
||||||
|
t.Error("Thinking text should be accumulated")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,6 +99,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
|||||||
cache.ClearSignatureCache("")
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
requestJSON := []byte(`{
|
requestJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
|
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
@@ -129,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
|||||||
// Process thinking chunk
|
// Process thinking chunk
|
||||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
|
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
|
||||||
params := param.(*Params)
|
params := param.(*Params)
|
||||||
sessionID := params.SessionID
|
|
||||||
thinkingText := params.CurrentThinkingText.String()
|
thinkingText := params.CurrentThinkingText.String()
|
||||||
|
|
||||||
if sessionID == "" {
|
|
||||||
t.Fatal("SessionID should be set")
|
|
||||||
}
|
|
||||||
if thinkingText == "" {
|
if thinkingText == "" {
|
||||||
t.Fatal("Thinking text should be accumulated")
|
t.Fatal("Thinking text should be accumulated")
|
||||||
}
|
}
|
||||||
@@ -143,7 +142,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
|||||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
|
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
|
||||||
|
|
||||||
// Verify signature was cached
|
// Verify signature was cached
|
||||||
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
|
cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText)
|
||||||
if cachedSig != validSignature {
|
if cachedSig != validSignature {
|
||||||
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
|
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
|
||||||
}
|
}
|
||||||
@@ -158,6 +157,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
|||||||
cache.ClearSignatureCache("")
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
requestJSON := []byte(`{
|
requestJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
|
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
@@ -221,13 +221,12 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
|||||||
// Process first thinking block
|
// Process first thinking block
|
||||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
|
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
|
||||||
params := param.(*Params)
|
params := param.(*Params)
|
||||||
sessionID := params.SessionID
|
|
||||||
firstThinkingText := params.CurrentThinkingText.String()
|
firstThinkingText := params.CurrentThinkingText.String()
|
||||||
|
|
||||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
|
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
|
||||||
|
|
||||||
// Verify first signature cached
|
// Verify first signature cached
|
||||||
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
|
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 {
|
||||||
t.Error("First thinking block signature should be cached")
|
t.Error("First thinking block signature should be cached")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,76 +240,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
|||||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
|
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
|
||||||
|
|
||||||
// Verify second signature cached
|
// Verify second signature cached
|
||||||
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
|
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 {
|
||||||
t.Error("Second thinking block signature should be cached")
|
t.Error("Second thinking block signature should be cached")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeriveSessionIDFromRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input []byte
|
|
||||||
wantEmpty bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid user message",
|
|
||||||
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
|
|
||||||
wantEmpty: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user message with content array",
|
|
||||||
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
|
|
||||||
wantEmpty: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no user message",
|
|
||||||
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
|
|
||||||
wantEmpty: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty messages",
|
|
||||||
input: []byte(`{"messages": []}`),
|
|
||||||
wantEmpty: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no messages field",
|
|
||||||
input: []byte(`{}`),
|
|
||||||
wantEmpty: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := deriveSessionID(tt.input)
|
|
||||||
if tt.wantEmpty && result != "" {
|
|
||||||
t.Errorf("Expected empty session ID, got '%s'", result)
|
|
||||||
}
|
|
||||||
if !tt.wantEmpty && result == "" {
|
|
||||||
t.Error("Expected non-empty session ID")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
|
|
||||||
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
|
|
||||||
|
|
||||||
id1 := deriveSessionID(input)
|
|
||||||
id2 := deriveSessionID(input)
|
|
||||||
|
|
||||||
if id1 != id2 {
|
|
||||||
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
|
|
||||||
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
|
|
||||||
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
|
|
||||||
|
|
||||||
id1 := deriveSessionID(input1)
|
|
||||||
id2 := deriveSessionID(input2)
|
|
||||||
|
|
||||||
if id1 == id2 {
|
|
||||||
t.Error("Different messages should produce different session IDs")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ package gemini
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
@@ -32,12 +33,12 @@ import (
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - []byte: The transformed request data in Gemini API format
|
// - []byte: The transformed request data in Gemini API format
|
||||||
func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
rawJSON := bytes.Clone(inputRawJSON)
|
rawJSON := bytes.Clone(inputRawJSON)
|
||||||
template := ""
|
template := ""
|
||||||
template = `{"project":"","request":{},"model":""}`
|
template = `{"project":"","request":{},"model":""}`
|
||||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
template, _ = sjson.Set(template, "model", modelName)
|
||||||
template, _ = sjson.Delete(template, "request.model")
|
template, _ = sjson.Delete(template, "request.model")
|
||||||
|
|
||||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||||
@@ -97,37 +98,40 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
|
// Gemini-specific handling for non-Claude models:
|
||||||
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
|
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
|
||||||
const skipSentinel = "skip_thought_signature_validator"
|
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
|
||||||
|
if !strings.Contains(modelName, "claude") {
|
||||||
|
const skipSentinel = "skip_thought_signature_validator"
|
||||||
|
|
||||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||||
if content.Get("role").String() == "model" {
|
if content.Get("role").String() == "model" {
|
||||||
// First pass: collect indices of thinking parts to remove
|
// First pass: collect indices of thinking parts to mark with skip sentinel
|
||||||
var thinkingIndicesToRemove []int64
|
var thinkingIndicesToSkipSignature []int64
|
||||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||||
// Mark thinking blocks for removal
|
// Collect indices of thinking blocks to mark with skip sentinel
|
||||||
if part.Get("thought").Bool() {
|
if part.Get("thought").Bool() {
|
||||||
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
|
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
|
||||||
}
|
|
||||||
// Add skip sentinel to functionCall parts
|
|
||||||
if part.Get("functionCall").Exists() {
|
|
||||||
existingSig := part.Get("thoughtSignature").String()
|
|
||||||
if existingSig == "" || len(existingSig) < 50 {
|
|
||||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
|
||||||
}
|
}
|
||||||
}
|
// Add skip sentinel to functionCall parts
|
||||||
return true
|
if part.Get("functionCall").Exists() {
|
||||||
})
|
existingSig := part.Get("thoughtSignature").String()
|
||||||
|
if existingSig == "" || len(existingSig) < 50 {
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
// Remove thinking blocks in reverse order to preserve indices
|
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
|
||||||
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
|
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
|
||||||
idx := thinkingIndicesToRemove[i]
|
idx := thinkingIndicesToSkipSignature[i]
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
|
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
return true
|
})
|
||||||
})
|
}
|
||||||
|
|
||||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,40 +62,6 @@ func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *test
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
|
|
||||||
// Thinking blocks should be removed entirely for Gemini
|
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
|
||||||
inputJSON := []byte(fmt.Sprintf(`{
|
|
||||||
"model": "gemini-3-pro-preview",
|
|
||||||
"contents": [
|
|
||||||
{
|
|
||||||
"role": "model",
|
|
||||||
"parts": [
|
|
||||||
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
|
|
||||||
{"text": "Here is my response"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`, validSignature))
|
|
||||||
|
|
||||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
|
||||||
outputStr := string(output)
|
|
||||||
|
|
||||||
// Check that thinking block is removed
|
|
||||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
|
||||||
if len(parts) != 1 {
|
|
||||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only text part should remain
|
|
||||||
if parts[0].Get("thought").Bool() {
|
|
||||||
t.Error("Thinking block should be removed for Gemini")
|
|
||||||
}
|
|
||||||
if parts[0].Get("text").String() != "Here is my response" {
|
|
||||||
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||||
// Multiple functionCalls should all get skip_thought_signature_validator
|
// Multiple functionCalls should all get skip_thought_signature_validator
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
|
|||||||
@@ -66,6 +66,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Candidate count (OpenAI 'n' parameter)
|
||||||
|
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
|
||||||
|
if val := n.Int(); val > 1 {
|
||||||
|
out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||||
@@ -298,12 +305,12 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
|
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
|
||||||
tools := gjson.GetBytes(rawJSON, "tools")
|
tools := gjson.GetBytes(rawJSON, "tools")
|
||||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
toolNode := []byte(`{}`)
|
functionToolNode := []byte(`{}`)
|
||||||
hasTool := false
|
|
||||||
hasFunction := false
|
hasFunction := false
|
||||||
|
googleSearchNodes := make([][]byte, 0)
|
||||||
for _, t := range tools.Array() {
|
for _, t := range tools.Array() {
|
||||||
if t.Get("type").String() == "function" {
|
if t.Get("type").String() == "function" {
|
||||||
fn := t.Get("function")
|
fn := t.Get("function")
|
||||||
@@ -342,31 +349,37 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
toolNode = tmp
|
functionToolNode = tmp
|
||||||
hasFunction = true
|
hasFunction = true
|
||||||
hasTool = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if gs := t.Get("google_search"); gs.Exists() {
|
if gs := t.Get("google_search"); gs.Exists() {
|
||||||
|
googleToolNode := []byte(`{}`)
|
||||||
var errSet error
|
var errSet error
|
||||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hasTool = true
|
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasTool {
|
if hasFunction || len(googleSearchNodes) > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
|
toolsNode := []byte("[]")
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
|
if hasFunction {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||||
|
}
|
||||||
|
for _, googleNode := range googleSearchNodes {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -98,9 +98,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
// Temperature setting for controlling response randomness
|
// Temperature setting for controlling response randomness
|
||||||
if temp := genConfig.Get("temperature"); temp.Exists() {
|
if temp := genConfig.Get("temperature"); temp.Exists() {
|
||||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||||
}
|
} else if topP := genConfig.Get("topP"); topP.Exists() {
|
||||||
// Top P setting for nucleus sampling
|
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||||
if topP := genConfig.Get("topP"); topP.Exists() {
|
|
||||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||||
}
|
}
|
||||||
// Stop sequences configuration for custom termination conditions
|
// Stop sequences configuration for custom termination conditions
|
||||||
|
|||||||
@@ -110,10 +110,8 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
// Temperature setting for controlling response randomness
|
// Temperature setting for controlling response randomness
|
||||||
if temp := root.Get("temperature"); temp.Exists() {
|
if temp := root.Get("temperature"); temp.Exists() {
|
||||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||||
}
|
} else if topP := root.Get("top_p"); topP.Exists() {
|
||||||
|
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||||
// Top P setting for nucleus sampling
|
|
||||||
if topP := root.Get("top_p"); topP.Exists() {
|
|
||||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -117,8 +117,12 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "usage.input_tokens", rootResult.Get("response.usage.input_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
|
||||||
template, _ = sjson.Set(template, "usage.output_tokens", rootResult.Get("response.usage.output_tokens").Int())
|
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
|
||||||
|
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
output = "event: message_delta\n"
|
output = "event: message_delta\n"
|
||||||
output += fmt.Sprintf("data: %s\n\n", template)
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
@@ -204,8 +208,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
||||||
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
hasToolCall := false
|
hasToolCall := false
|
||||||
|
|
||||||
@@ -308,12 +316,27 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
|
return out
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
}
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
|
||||||
|
func extractResponsesUsage(usage gjson.Result) (int64, int64, int64) {
|
||||||
|
if !usage.Exists() || usage.Type == gjson.Null {
|
||||||
|
return 0, 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
inputTokens := usage.Get("input_tokens").Int()
|
||||||
|
outputTokens := usage.Get("output_tokens").Int()
|
||||||
|
cachedTokens := usage.Get("input_tokens_details.cached_tokens").Int()
|
||||||
|
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
if inputTokens >= cachedTokens {
|
||||||
|
inputTokens -= cachedTokens
|
||||||
|
} else {
|
||||||
|
inputTokens = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputTokens, outputTokens, cachedTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.
|
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.
|
||||||
|
|||||||
@@ -63,6 +63,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Candidate count (OpenAI 'n' parameter)
|
||||||
|
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
|
||||||
|
if val := n.Int(); val > 1 {
|
||||||
|
out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||||
@@ -276,12 +283,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
|
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
|
||||||
tools := gjson.GetBytes(rawJSON, "tools")
|
tools := gjson.GetBytes(rawJSON, "tools")
|
||||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
toolNode := []byte(`{}`)
|
functionToolNode := []byte(`{}`)
|
||||||
hasTool := false
|
|
||||||
hasFunction := false
|
hasFunction := false
|
||||||
|
googleSearchNodes := make([][]byte, 0)
|
||||||
for _, t := range tools.Array() {
|
for _, t := range tools.Array() {
|
||||||
if t.Get("type").String() == "function" {
|
if t.Get("type").String() == "function" {
|
||||||
fn := t.Get("function")
|
fn := t.Get("function")
|
||||||
@@ -320,31 +327,37 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
toolNode = tmp
|
functionToolNode = tmp
|
||||||
hasFunction = true
|
hasFunction = true
|
||||||
hasTool = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if gs := t.Get("google_search"); gs.Exists() {
|
if gs := t.Get("google_search"); gs.Exists() {
|
||||||
|
googleToolNode := []byte(`{}`)
|
||||||
var errSet error
|
var errSet error
|
||||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hasTool = true
|
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasTool {
|
if hasFunction || len(googleSearchNodes) > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
|
toolsNode := []byte("[]")
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
|
if hasFunction {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||||
|
}
|
||||||
|
for _, googleNode := range googleSearchNodes {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,13 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
out, _ = sjson.SetBytes(out, "generationConfig.topK", tkr.Num)
|
out, _ = sjson.SetBytes(out, "generationConfig.topK", tkr.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Candidate count (OpenAI 'n' parameter)
|
||||||
|
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
|
||||||
|
if val := n.Int(); val > 1 {
|
||||||
|
out, _ = sjson.SetBytes(out, "generationConfig.candidateCount", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Map OpenAI modalities -> Gemini generationConfig.responseModalities
|
// Map OpenAI modalities -> Gemini generationConfig.responseModalities
|
||||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||||
@@ -282,12 +289,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough
|
// tools -> tools[].functionDeclarations + tools[].googleSearch passthrough
|
||||||
tools := gjson.GetBytes(rawJSON, "tools")
|
tools := gjson.GetBytes(rawJSON, "tools")
|
||||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
toolNode := []byte(`{}`)
|
functionToolNode := []byte(`{}`)
|
||||||
hasTool := false
|
|
||||||
hasFunction := false
|
hasFunction := false
|
||||||
|
googleSearchNodes := make([][]byte, 0)
|
||||||
for _, t := range tools.Array() {
|
for _, t := range tools.Array() {
|
||||||
if t.Get("type").String() == "function" {
|
if t.Get("type").String() == "function" {
|
||||||
fn := t.Get("function")
|
fn := t.Get("function")
|
||||||
@@ -326,31 +333,37 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
toolNode = tmp
|
functionToolNode = tmp
|
||||||
hasFunction = true
|
hasFunction = true
|
||||||
hasTool = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if gs := t.Get("google_search"); gs.Exists() {
|
if gs := t.Get("google_search"); gs.Exists() {
|
||||||
|
googleToolNode := []byte(`{}`)
|
||||||
var errSet error
|
var errSet error
|
||||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hasTool = true
|
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasTool {
|
if hasFunction || len(googleSearchNodes) > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "tools", []byte("[]"))
|
toolsNode := []byte("[]")
|
||||||
out, _ = sjson.SetRawBytes(out, "tools.0", toolNode)
|
if hasFunction {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||||
|
}
|
||||||
|
for _, googleNode := range googleSearchNodes {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "tools", toolsNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ import (
|
|||||||
// convertGeminiResponseToOpenAIChatParams holds parameters for response conversion.
|
// convertGeminiResponseToOpenAIChatParams holds parameters for response conversion.
|
||||||
type convertGeminiResponseToOpenAIChatParams struct {
|
type convertGeminiResponseToOpenAIChatParams struct {
|
||||||
UnixTimestamp int64
|
UnixTimestamp int64
|
||||||
FunctionIndex int
|
// FunctionIndex tracks tool call indices per candidate index to support multiple candidates.
|
||||||
|
FunctionIndex map[int]int
|
||||||
}
|
}
|
||||||
|
|
||||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
@@ -42,13 +43,20 @@ var functionCallIDCounter uint64
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||||
func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||||
|
// Initialize parameters if nil.
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &convertGeminiResponseToOpenAIChatParams{
|
*param = &convertGeminiResponseToOpenAIChatParams{
|
||||||
UnixTimestamp: 0,
|
UnixTimestamp: 0,
|
||||||
FunctionIndex: 0,
|
FunctionIndex: make(map[int]int),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure the Map is initialized (handling cases where param might be reused from older context).
|
||||||
|
p := (*param).(*convertGeminiResponseToOpenAIChatParams)
|
||||||
|
if p.FunctionIndex == nil {
|
||||||
|
p.FunctionIndex = make(map[int]int)
|
||||||
|
}
|
||||||
|
|
||||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||||
}
|
}
|
||||||
@@ -57,151 +65,179 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the OpenAI SSE template.
|
// Initialize the OpenAI SSE base template.
|
||||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
// We use a base template and clone it for each candidate to support multiple candidates.
|
||||||
|
baseTemplate := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||||
|
|
||||||
// Extract and set the model version.
|
// Extract and set the model version.
|
||||||
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
|
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
baseTemplate, _ = sjson.Set(baseTemplate, "model", modelVersionResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the creation timestamp.
|
// Extract and set the creation timestamp.
|
||||||
if createTimeResult := gjson.GetBytes(rawJSON, "createTime"); createTimeResult.Exists() {
|
if createTimeResult := gjson.GetBytes(rawJSON, "createTime"); createTimeResult.Exists() {
|
||||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
(*param).(*convertGeminiResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
p.UnixTimestamp = t.Unix()
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "created", (*param).(*convertGeminiResponseToOpenAIChatParams).UnixTimestamp)
|
baseTemplate, _ = sjson.Set(baseTemplate, "created", p.UnixTimestamp)
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "created", (*param).(*convertGeminiResponseToOpenAIChatParams).UnixTimestamp)
|
baseTemplate, _ = sjson.Set(baseTemplate, "created", p.UnixTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the response ID.
|
// Extract and set the response ID.
|
||||||
if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() {
|
if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
baseTemplate, _ = sjson.Set(baseTemplate, "id", responseIDResult.String())
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and set the finish reason.
|
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
|
// Usage is applied to the base template so it appears in the chunks.
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
// Include cached token count if present (indicates prompt caching is working)
|
// Include cached token count if present (indicates prompt caching is working)
|
||||||
if cachedTokenCount > 0 {
|
if cachedTokenCount > 0 {
|
||||||
var err error
|
var err error
|
||||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
baseTemplate, err = sjson.Set(baseTemplate, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
|
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
var responseStrings []string
|
||||||
partsResult := gjson.GetBytes(rawJSON, "candidates.0.content.parts")
|
candidates := gjson.GetBytes(rawJSON, "candidates")
|
||||||
hasFunctionCall := false
|
|
||||||
if partsResult.IsArray() {
|
// Iterate over all candidates to support candidate_count > 1.
|
||||||
partResults := partsResult.Array()
|
if candidates.IsArray() {
|
||||||
for i := 0; i < len(partResults); i++ {
|
candidates.ForEach(func(_, candidate gjson.Result) bool {
|
||||||
partResult := partResults[i]
|
// Clone the template for the current candidate.
|
||||||
partTextResult := partResult.Get("text")
|
template := baseTemplate
|
||||||
functionCallResult := partResult.Get("functionCall")
|
|
||||||
inlineDataResult := partResult.Get("inlineData")
|
// Set the specific index for this candidate.
|
||||||
if !inlineDataResult.Exists() {
|
candidateIndex := int(candidate.Get("index").Int())
|
||||||
inlineDataResult = partResult.Get("inline_data")
|
template, _ = sjson.Set(template, "choices.0.index", candidateIndex)
|
||||||
}
|
|
||||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
// Extract and set the finish reason.
|
||||||
if !thoughtSignatureResult.Exists() {
|
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
partsResult := candidate.Get("content.parts")
|
||||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
hasFunctionCall := false
|
||||||
|
|
||||||
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
|
if partsResult.IsArray() {
|
||||||
if hasThoughtSignature && !hasContentPayload {
|
partResults := partsResult.Array()
|
||||||
continue
|
for i := 0; i < len(partResults); i++ {
|
||||||
|
partResult := partResults[i]
|
||||||
|
partTextResult := partResult.Get("text")
|
||||||
|
functionCallResult := partResult.Get("functionCall")
|
||||||
|
inlineDataResult := partResult.Get("inlineData")
|
||||||
|
if !inlineDataResult.Exists() {
|
||||||
|
inlineDataResult = partResult.Get("inline_data")
|
||||||
|
}
|
||||||
|
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||||
|
if !thoughtSignatureResult.Exists() {
|
||||||
|
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||||
|
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||||
|
|
||||||
|
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
|
||||||
|
if hasThoughtSignature && !hasContentPayload {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if partTextResult.Exists() {
|
||||||
|
text := partTextResult.String()
|
||||||
|
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||||
|
if partResult.Get("thought").Bool() {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", text)
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.content", text)
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
} else if functionCallResult.Exists() {
|
||||||
|
// Handle function call content.
|
||||||
|
hasFunctionCall = true
|
||||||
|
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||||
|
|
||||||
|
// Retrieve the function index for this specific candidate.
|
||||||
|
functionCallIndex := p.FunctionIndex[candidateIndex]
|
||||||
|
p.FunctionIndex[candidateIndex]++
|
||||||
|
|
||||||
|
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||||
|
functionCallIndex = len(toolCallsResult.Array())
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||||
|
fcName := functionCallResult.Get("name").String()
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||||
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||||
|
} else if inlineDataResult.Exists() {
|
||||||
|
data := inlineDataResult.Get("data").String()
|
||||||
|
if data == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mimeType := inlineDataResult.Get("mimeType").String()
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = inlineDataResult.Get("mime_type").String()
|
||||||
|
}
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/png"
|
||||||
|
}
|
||||||
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
|
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||||
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||||
|
}
|
||||||
|
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||||
|
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if partTextResult.Exists() {
|
if hasFunctionCall {
|
||||||
text := partTextResult.String()
|
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||||
if partResult.Get("thought").Bool() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", text)
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.content", text)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
} else if functionCallResult.Exists() {
|
|
||||||
// Handle function call content.
|
|
||||||
hasFunctionCall = true
|
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
|
||||||
functionCallIndex := (*param).(*convertGeminiResponseToOpenAIChatParams).FunctionIndex
|
|
||||||
(*param).(*convertGeminiResponseToOpenAIChatParams).FunctionIndex++
|
|
||||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
|
||||||
functionCallIndex = len(toolCallsResult.Array())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
|
||||||
}
|
|
||||||
|
|
||||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
|
||||||
fcName := functionCallResult.Get("name").String()
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
|
||||||
} else if inlineDataResult.Exists() {
|
|
||||||
data := inlineDataResult.Get("data").String()
|
|
||||||
if data == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mimeType := inlineDataResult.Get("mimeType").String()
|
|
||||||
if mimeType == "" {
|
|
||||||
mimeType = inlineDataResult.Get("mime_type").String()
|
|
||||||
}
|
|
||||||
if mimeType == "" {
|
|
||||||
mimeType = "image/png"
|
|
||||||
}
|
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
|
||||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
|
||||||
}
|
|
||||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
|
||||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
responseStrings = append(responseStrings, template)
|
||||||
|
return true // continue loop
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// If there are no candidates (e.g., a pure usageMetadata chunk), return the usage chunk if present.
|
||||||
|
if gjson.GetBytes(rawJSON, "usageMetadata").Exists() && len(responseStrings) == 0 {
|
||||||
|
responseStrings = append(responseStrings, baseTemplate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasFunctionCall {
|
return responseStrings
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
|
||||||
}
|
|
||||||
|
|
||||||
return []string{template}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertGeminiResponseToOpenAINonStream converts a non-streaming Gemini response to a non-streaming OpenAI response.
|
// ConvertGeminiResponseToOpenAINonStream converts a non-streaming Gemini response to a non-streaming OpenAI response.
|
||||||
@@ -219,7 +255,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||||
func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||||
var unixTimestamp int64
|
var unixTimestamp int64
|
||||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
// Initialize template with an empty choices array to support multiple candidates.
|
||||||
|
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[]}`
|
||||||
|
|
||||||
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
|
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||||
}
|
}
|
||||||
@@ -238,11 +276,6 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
@@ -267,74 +300,96 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response for all candidates.
|
||||||
partsResult := gjson.GetBytes(rawJSON, "candidates.0.content.parts")
|
candidates := gjson.GetBytes(rawJSON, "candidates")
|
||||||
hasFunctionCall := false
|
if candidates.IsArray() {
|
||||||
if partsResult.IsArray() {
|
candidates.ForEach(func(_, candidate gjson.Result) bool {
|
||||||
partsResults := partsResult.Array()
|
// Construct a single Choice object.
|
||||||
for i := 0; i < len(partsResults); i++ {
|
choiceTemplate := `{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}`
|
||||||
partResult := partsResults[i]
|
|
||||||
partTextResult := partResult.Get("text")
|
// Set the index for this choice.
|
||||||
functionCallResult := partResult.Get("functionCall")
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "index", candidate.Get("index").Int())
|
||||||
inlineDataResult := partResult.Get("inlineData")
|
|
||||||
if !inlineDataResult.Exists() {
|
// Set finish reason.
|
||||||
inlineDataResult = partResult.Get("inline_data")
|
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if partTextResult.Exists() {
|
partsResult := candidate.Get("content.parts")
|
||||||
// Append text content, distinguishing between regular content and reasoning.
|
hasFunctionCall := false
|
||||||
if partResult.Get("thought").Bool() {
|
if partsResult.IsArray() {
|
||||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
|
partsResults := partsResult.Array()
|
||||||
} else {
|
for i := 0; i < len(partsResults); i++ {
|
||||||
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
|
partResult := partsResults[i]
|
||||||
}
|
partTextResult := partResult.Get("text")
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
functionCallResult := partResult.Get("functionCall")
|
||||||
} else if functionCallResult.Exists() {
|
inlineDataResult := partResult.Get("inlineData")
|
||||||
// Append function call content to the tool_calls array.
|
if !inlineDataResult.Exists() {
|
||||||
hasFunctionCall = true
|
inlineDataResult = partResult.Get("inline_data")
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
|
}
|
||||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
|
||||||
}
|
|
||||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
|
||||||
fcName := functionCallResult.Get("name").String()
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
|
||||||
} else if inlineDataResult.Exists() {
|
|
||||||
data := inlineDataResult.Get("data").String()
|
|
||||||
if data == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mimeType := inlineDataResult.Get("mimeType").String()
|
|
||||||
if mimeType == "" {
|
|
||||||
mimeType = inlineDataResult.Get("mime_type").String()
|
|
||||||
}
|
|
||||||
if mimeType == "" {
|
|
||||||
mimeType = "image/png"
|
|
||||||
}
|
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
|
||||||
imagesResult := gjson.Get(template, "choices.0.message.images")
|
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
|
|
||||||
}
|
|
||||||
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
|
|
||||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasFunctionCall {
|
if partTextResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
// Append text content, distinguishing between regular content and reasoning.
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
if partResult.Get("thought").Bool() {
|
||||||
|
oldVal := gjson.Get(choiceTemplate, "message.reasoning_content").String()
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.reasoning_content", oldVal+partTextResult.String())
|
||||||
|
} else {
|
||||||
|
oldVal := gjson.Get(choiceTemplate, "message.content").String()
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.content", oldVal+partTextResult.String())
|
||||||
|
}
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.role", "assistant")
|
||||||
|
} else if functionCallResult.Exists() {
|
||||||
|
// Append function call content to the tool_calls array.
|
||||||
|
hasFunctionCall = true
|
||||||
|
toolCallsResult := gjson.Get(choiceTemplate, "message.tool_calls")
|
||||||
|
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||||
|
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.tool_calls", `[]`)
|
||||||
|
}
|
||||||
|
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||||
|
fcName := functionCallResult.Get("name").String()
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||||
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
|
}
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.role", "assistant")
|
||||||
|
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.tool_calls.-1", functionCallItemTemplate)
|
||||||
|
} else if inlineDataResult.Exists() {
|
||||||
|
data := inlineDataResult.Get("data").String()
|
||||||
|
if data != "" {
|
||||||
|
mimeType := inlineDataResult.Get("mimeType").String()
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = inlineDataResult.Get("mime_type").String()
|
||||||
|
}
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/png"
|
||||||
|
}
|
||||||
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
|
imagesResult := gjson.Get(choiceTemplate, "message.images")
|
||||||
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
|
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.images", `[]`)
|
||||||
|
}
|
||||||
|
imageIndex := len(gjson.Get(choiceTemplate, "message.images").Array())
|
||||||
|
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||||
|
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.role", "assistant")
|
||||||
|
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.images.-1", imagePayload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasFunctionCall {
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "finish_reason", "tool_calls")
|
||||||
|
choiceTemplate, _ = sjson.Set(choiceTemplate, "native_finish_reason", "tool_calls")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the constructed choice to the main choices array.
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.-1", choiceTemplate)
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|||||||
@@ -298,6 +298,15 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
}
|
}
|
||||||
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
|
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
|
||||||
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
|
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
|
||||||
|
|
||||||
|
case "reasoning":
|
||||||
|
thoughtContent := `{"role":"model","parts":[]}`
|
||||||
|
thought := `{"text":"","thoughtSignature":"","thought":true}`
|
||||||
|
thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String())
|
||||||
|
thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String())
|
||||||
|
|
||||||
|
thoughtContent, _ = sjson.SetRaw(thoughtContent, "parts.-1", thought)
|
||||||
|
out, _ = sjson.SetRaw(out, "contents.-1", thoughtContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if input.Exists() && input.Type == gjson.String {
|
} else if input.Exists() && input.Type == gjson.String {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type geminiToResponsesState struct {
|
|||||||
|
|
||||||
// message aggregation
|
// message aggregation
|
||||||
MsgOpened bool
|
MsgOpened bool
|
||||||
|
MsgClosed bool
|
||||||
MsgIndex int
|
MsgIndex int
|
||||||
CurrentMsgID string
|
CurrentMsgID string
|
||||||
TextBuf strings.Builder
|
TextBuf strings.Builder
|
||||||
@@ -29,6 +30,7 @@ type geminiToResponsesState struct {
|
|||||||
ReasoningOpened bool
|
ReasoningOpened bool
|
||||||
ReasoningIndex int
|
ReasoningIndex int
|
||||||
ReasoningItemID string
|
ReasoningItemID string
|
||||||
|
ReasoningEnc string
|
||||||
ReasoningBuf strings.Builder
|
ReasoningBuf strings.Builder
|
||||||
ReasoningClosed bool
|
ReasoningClosed bool
|
||||||
|
|
||||||
@@ -37,6 +39,7 @@ type geminiToResponsesState struct {
|
|||||||
FuncArgsBuf map[int]*strings.Builder
|
FuncArgsBuf map[int]*strings.Builder
|
||||||
FuncNames map[int]string
|
FuncNames map[int]string
|
||||||
FuncCallIDs map[int]string
|
FuncCallIDs map[int]string
|
||||||
|
FuncDone map[int]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||||
@@ -45,6 +48,39 @@ var responseIDCounter uint64
|
|||||||
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
|
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
var funcCallIDCounter uint64
|
var funcCallIDCounter uint64
|
||||||
|
|
||||||
|
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
|
||||||
|
if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) {
|
||||||
|
return originalRequestRawJSON
|
||||||
|
}
|
||||||
|
if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) {
|
||||||
|
return requestRawJSON
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unwrapRequestRoot(root gjson.Result) gjson.Result {
|
||||||
|
req := root.Get("request")
|
||||||
|
if !req.Exists() {
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
if req.Get("model").Exists() || req.Get("input").Exists() || req.Get("instructions").Exists() {
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
|
||||||
|
func unwrapGeminiResponseRoot(root gjson.Result) gjson.Result {
|
||||||
|
resp := root.Get("response")
|
||||||
|
if !resp.Exists() {
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
// Vertex-style Gemini responses wrap the actual payload in a "response" object.
|
||||||
|
if resp.Get("candidates").Exists() || resp.Get("responseId").Exists() || resp.Get("usageMetadata").Exists() {
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
|
||||||
func emitEvent(event string, payload string) string {
|
func emitEvent(event string, payload string) string {
|
||||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||||
}
|
}
|
||||||
@@ -56,18 +92,37 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||||
FuncNames: make(map[int]string),
|
FuncNames: make(map[int]string),
|
||||||
FuncCallIDs: make(map[int]string),
|
FuncCallIDs: make(map[int]string),
|
||||||
|
FuncDone: make(map[int]bool),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
st := (*param).(*geminiToResponsesState)
|
st := (*param).(*geminiToResponsesState)
|
||||||
|
if st.FuncArgsBuf == nil {
|
||||||
|
st.FuncArgsBuf = make(map[int]*strings.Builder)
|
||||||
|
}
|
||||||
|
if st.FuncNames == nil {
|
||||||
|
st.FuncNames = make(map[int]string)
|
||||||
|
}
|
||||||
|
if st.FuncCallIDs == nil {
|
||||||
|
st.FuncCallIDs = make(map[int]string)
|
||||||
|
}
|
||||||
|
if st.FuncDone == nil {
|
||||||
|
st.FuncDone = make(map[int]bool)
|
||||||
|
}
|
||||||
|
|
||||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawJSON = bytes.TrimSpace(rawJSON)
|
||||||
|
if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
root := gjson.ParseBytes(rawJSON)
|
root := gjson.ParseBytes(rawJSON)
|
||||||
if !root.Exists() {
|
if !root.Exists() {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
root = unwrapGeminiResponseRoot(root)
|
||||||
|
|
||||||
var out []string
|
var out []string
|
||||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||||
@@ -98,19 +153,54 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||||
itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID)
|
itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID)
|
||||||
itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex)
|
itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex)
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "item.encrypted_content", st.ReasoningEnc)
|
||||||
itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full)
|
itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full)
|
||||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||||
|
|
||||||
st.ReasoningClosed = true
|
st.ReasoningClosed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper to finalize the assistant message in correct order.
|
||||||
|
// It emits response.output_text.done, response.content_part.done,
|
||||||
|
// and response.output_item.done exactly once.
|
||||||
|
finalizeMessage := func() {
|
||||||
|
if !st.MsgOpened || st.MsgClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fullText := st.ItemTextBuf.String()
|
||||||
|
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||||
|
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||||
|
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||||
|
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||||
|
done, _ = sjson.Set(done, "text", fullText)
|
||||||
|
out = append(out, emitEvent("response.output_text.done", done))
|
||||||
|
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||||
|
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||||
|
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||||
|
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||||
|
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||||
|
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||||
|
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||||
|
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||||
|
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||||
|
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||||
|
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||||
|
out = append(out, emitEvent("response.output_item.done", final))
|
||||||
|
|
||||||
|
st.MsgClosed = true
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize per-response fields and emit created/in_progress once
|
// Initialize per-response fields and emit created/in_progress once
|
||||||
if !st.Started {
|
if !st.Started {
|
||||||
if v := root.Get("responseId"); v.Exists() {
|
st.ResponseID = root.Get("responseId").String()
|
||||||
st.ResponseID = v.String()
|
if st.ResponseID == "" {
|
||||||
|
st.ResponseID = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(st.ResponseID, "resp_") {
|
||||||
|
st.ResponseID = fmt.Sprintf("resp_%s", st.ResponseID)
|
||||||
}
|
}
|
||||||
if v := root.Get("createTime"); v.Exists() {
|
if v := root.Get("createTime"); v.Exists() {
|
||||||
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
|
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
|
||||||
st.CreatedAt = t.Unix()
|
st.CreatedAt = t.Unix()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -143,15 +233,21 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
// Ignore any late thought chunks after reasoning is finalized.
|
// Ignore any late thought chunks after reasoning is finalized.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
|
||||||
|
st.ReasoningEnc = sig.String()
|
||||||
|
} else if sig = part.Get("thought_signature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
|
||||||
|
st.ReasoningEnc = sig.String()
|
||||||
|
}
|
||||||
if !st.ReasoningOpened {
|
if !st.ReasoningOpened {
|
||||||
st.ReasoningOpened = true
|
st.ReasoningOpened = true
|
||||||
st.ReasoningIndex = st.NextIndex
|
st.ReasoningIndex = st.NextIndex
|
||||||
st.NextIndex++
|
st.NextIndex++
|
||||||
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex)
|
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex)
|
||||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
|
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","encrypted_content":"","summary":[]}}`
|
||||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||||
item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
|
item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
|
||||||
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
|
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
|
||||||
|
item, _ = sjson.Set(item, "item.encrypted_content", st.ReasoningEnc)
|
||||||
out = append(out, emitEvent("response.output_item.added", item))
|
out = append(out, emitEvent("response.output_item.added", item))
|
||||||
partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||||
partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
|
partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
|
||||||
@@ -191,9 +287,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
|
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
|
||||||
out = append(out, emitEvent("response.content_part.added", partAdded))
|
out = append(out, emitEvent("response.content_part.added", partAdded))
|
||||||
st.ItemTextBuf.Reset()
|
st.ItemTextBuf.Reset()
|
||||||
st.ItemTextBuf.WriteString(t.String())
|
|
||||||
}
|
}
|
||||||
st.TextBuf.WriteString(t.String())
|
st.TextBuf.WriteString(t.String())
|
||||||
|
st.ItemTextBuf.WriteString(t.String())
|
||||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||||
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
|
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
|
||||||
@@ -205,8 +301,10 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
|
|
||||||
// Function call
|
// Function call
|
||||||
if fc := part.Get("functionCall"); fc.Exists() {
|
if fc := part.Get("functionCall"); fc.Exists() {
|
||||||
// Before emitting function-call outputs, finalize reasoning if open.
|
// Before emitting function-call outputs, finalize reasoning and the message (if open).
|
||||||
|
// Responses streaming requires message done events before the next output_item.added.
|
||||||
finalizeReasoning()
|
finalizeReasoning()
|
||||||
|
finalizeMessage()
|
||||||
name := fc.Get("name").String()
|
name := fc.Get("name").String()
|
||||||
idx := st.NextIndex
|
idx := st.NextIndex
|
||||||
st.NextIndex++
|
st.NextIndex++
|
||||||
@@ -219,6 +317,14 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
}
|
}
|
||||||
st.FuncNames[idx] = name
|
st.FuncNames[idx] = name
|
||||||
|
|
||||||
|
argsJSON := "{}"
|
||||||
|
if args := fc.Get("args"); args.Exists() {
|
||||||
|
argsJSON = args.Raw
|
||||||
|
}
|
||||||
|
if st.FuncArgsBuf[idx].Len() == 0 && argsJSON != "" {
|
||||||
|
st.FuncArgsBuf[idx].WriteString(argsJSON)
|
||||||
|
}
|
||||||
|
|
||||||
// Emit item.added for function call
|
// Emit item.added for function call
|
||||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
|
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
|
||||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||||
@@ -228,10 +334,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
item, _ = sjson.Set(item, "item.name", name)
|
item, _ = sjson.Set(item, "item.name", name)
|
||||||
out = append(out, emitEvent("response.output_item.added", item))
|
out = append(out, emitEvent("response.output_item.added", item))
|
||||||
|
|
||||||
// Emit arguments delta (full args in one chunk)
|
// Emit arguments delta (full args in one chunk).
|
||||||
if args := fc.Get("args"); args.Exists() {
|
// When Gemini omits args, emit "{}" to keep Responses streaming event order consistent.
|
||||||
argsJSON := args.Raw
|
if argsJSON != "" {
|
||||||
st.FuncArgsBuf[idx].WriteString(argsJSON)
|
|
||||||
ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
|
ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
|
||||||
ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
|
ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
|
||||||
ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||||
@@ -240,6 +345,27 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
out = append(out, emitEvent("response.function_call_arguments.delta", ad))
|
out = append(out, emitEvent("response.function_call_arguments.delta", ad))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gemini emits the full function call payload at once, so we can finalize it immediately.
|
||||||
|
if !st.FuncDone[idx] {
|
||||||
|
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
|
||||||
|
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
|
||||||
|
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||||
|
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
|
||||||
|
fcDone, _ = sjson.Set(fcDone, "arguments", argsJSON)
|
||||||
|
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
|
||||||
|
|
||||||
|
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "item.arguments", argsJSON)
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||||
|
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||||
|
|
||||||
|
st.FuncDone[idx] = true
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,28 +377,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" {
|
if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" {
|
||||||
// Finalize reasoning first to keep ordering tight with last delta
|
// Finalize reasoning first to keep ordering tight with last delta
|
||||||
finalizeReasoning()
|
finalizeReasoning()
|
||||||
// Close message output if opened
|
finalizeMessage()
|
||||||
if st.MsgOpened {
|
|
||||||
fullText := st.ItemTextBuf.String()
|
|
||||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
|
||||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
|
||||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
|
||||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
|
||||||
done, _ = sjson.Set(done, "text", fullText)
|
|
||||||
out = append(out, emitEvent("response.output_text.done", done))
|
|
||||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
|
||||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
|
||||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
|
||||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
|
||||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
|
||||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
|
||||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
|
||||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
|
||||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
|
||||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
|
||||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
|
||||||
out = append(out, emitEvent("response.output_item.done", final))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close function calls
|
// Close function calls
|
||||||
if len(st.FuncArgsBuf) > 0 {
|
if len(st.FuncArgsBuf) > 0 {
|
||||||
@@ -289,6 +394,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, idx := range idxs {
|
for _, idx := range idxs {
|
||||||
|
if st.FuncDone[idx] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
args := "{}"
|
args := "{}"
|
||||||
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
||||||
args = b.String()
|
args = b.String()
|
||||||
@@ -308,6 +416,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
||||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||||
|
|
||||||
|
st.FuncDone[idx] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,8 +429,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
|
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
|
||||||
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
||||||
|
|
||||||
if requestRawJSON != nil {
|
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
|
||||||
req := gjson.ParseBytes(requestRawJSON)
|
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
|
||||||
if v := req.Get("instructions"); v.Exists() {
|
if v := req.Get("instructions"); v.Exists() {
|
||||||
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
||||||
}
|
}
|
||||||
@@ -383,41 +493,34 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compose outputs in encountered order: reasoning, message, function_calls
|
// Compose outputs in output_index order.
|
||||||
outputsWrapper := `{"arr":[]}`
|
outputsWrapper := `{"arr":[]}`
|
||||||
if st.ReasoningOpened {
|
for idx := 0; idx < st.NextIndex; idx++ {
|
||||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
if st.ReasoningOpened && idx == st.ReasoningIndex {
|
||||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
item := `{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}`
|
||||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
item, _ = sjson.Set(item, "encrypted_content", st.ReasoningEnc)
|
||||||
}
|
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||||
if st.MsgOpened {
|
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
continue
|
||||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
|
||||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
|
||||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
|
||||||
}
|
|
||||||
if len(st.FuncArgsBuf) > 0 {
|
|
||||||
idxs := make([]int, 0, len(st.FuncArgsBuf))
|
|
||||||
for idx := range st.FuncArgsBuf {
|
|
||||||
idxs = append(idxs, idx)
|
|
||||||
}
|
}
|
||||||
for i := 0; i < len(idxs); i++ {
|
if st.MsgOpened && idx == st.MsgIndex {
|
||||||
for j := i + 1; j < len(idxs); j++ {
|
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||||
if idxs[j] < idxs[i] {
|
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||||
}
|
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||||
}
|
continue
|
||||||
}
|
}
|
||||||
for _, idx := range idxs {
|
|
||||||
args := ""
|
if callID, ok := st.FuncCallIDs[idx]; ok && callID != "" {
|
||||||
if b := st.FuncArgsBuf[idx]; b != nil {
|
args := "{}"
|
||||||
|
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
||||||
args = b.String()
|
args = b.String()
|
||||||
}
|
}
|
||||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||||
item, _ = sjson.Set(item, "arguments", args)
|
item, _ = sjson.Set(item, "arguments", args)
|
||||||
item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
|
item, _ = sjson.Set(item, "call_id", callID)
|
||||||
item, _ = sjson.Set(item, "name", st.FuncNames[idx])
|
item, _ = sjson.Set(item, "name", st.FuncNames[idx])
|
||||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||||
}
|
}
|
||||||
@@ -431,8 +534,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt + thoughts
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||||
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
// output tokens
|
// output tokens
|
||||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
|
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
|
||||||
@@ -460,6 +563,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
// ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object.
|
// ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object.
|
||||||
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||||
root := gjson.ParseBytes(rawJSON)
|
root := gjson.ParseBytes(rawJSON)
|
||||||
|
root = unwrapGeminiResponseRoot(root)
|
||||||
|
|
||||||
// Base response scaffold
|
// Base response scaffold
|
||||||
resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
|
resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
|
||||||
@@ -478,15 +582,15 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
// created_at: map from createTime if available
|
// created_at: map from createTime if available
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
if v := root.Get("createTime"); v.Exists() {
|
if v := root.Get("createTime"); v.Exists() {
|
||||||
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
|
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
|
||||||
createdAt = t.Unix()
|
createdAt = t.Unix()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp, _ = sjson.Set(resp, "created_at", createdAt)
|
resp, _ = sjson.Set(resp, "created_at", createdAt)
|
||||||
|
|
||||||
// Echo request fields when present; fallback model from response modelVersion
|
// Echo request fields when present; fallback model from response modelVersion
|
||||||
if len(requestRawJSON) > 0 {
|
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
|
||||||
req := gjson.ParseBytes(requestRawJSON)
|
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
|
||||||
if v := req.Get("instructions"); v.Exists() {
|
if v := req.Get("instructions"); v.Exists() {
|
||||||
resp, _ = sjson.Set(resp, "instructions", v.String())
|
resp, _ = sjson.Set(resp, "instructions", v.String())
|
||||||
}
|
}
|
||||||
@@ -636,8 +740,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt + thoughts
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||||
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", 0)
|
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
// output tokens
|
// output tokens
|
||||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||||
resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int())
|
resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int())
|
||||||
|
|||||||
@@ -0,0 +1,353 @@
|
|||||||
|
package responses
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseSSEEvent(t *testing.T, chunk string) (string, gjson.Result) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
lines := strings.Split(chunk, "\n")
|
||||||
|
if len(lines) < 2 {
|
||||||
|
t.Fatalf("unexpected SSE chunk: %q", chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||||
|
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||||
|
if !gjson.Valid(dataLine) {
|
||||||
|
t.Fatalf("invalid SSE data JSON: %q", dataLine)
|
||||||
|
}
|
||||||
|
return event, gjson.Parse(dataLine)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertGeminiResponseToOpenAIResponses_UnwrapAndAggregateText(t *testing.T) {
|
||||||
|
// Vertex-style Gemini stream wraps the actual response payload under "response".
|
||||||
|
// This test ensures we unwrap and that output_text.done contains the full text.
|
||||||
|
in := []string{
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"让"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"我先"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"了解"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"mcp__serena__list_dir","args":{"recursive":false,"relative_path":"internal"},"id":"toolu_1"}}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":2},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
originalReq := []byte(`{"instructions":"test instructions","model":"gpt-5","max_output_tokens":123}`)
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out []string
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", originalReq, nil, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
gotTextDone bool
|
||||||
|
gotMessageDone bool
|
||||||
|
gotResponseDone bool
|
||||||
|
gotFuncDone bool
|
||||||
|
|
||||||
|
textDone string
|
||||||
|
messageText string
|
||||||
|
responseID string
|
||||||
|
instructions string
|
||||||
|
cachedTokens int64
|
||||||
|
|
||||||
|
funcName string
|
||||||
|
funcArgs string
|
||||||
|
|
||||||
|
posTextDone = -1
|
||||||
|
posPartDone = -1
|
||||||
|
posMessageDone = -1
|
||||||
|
posFuncAdded = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, chunk := range out {
|
||||||
|
ev, data := parseSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_text.done":
|
||||||
|
gotTextDone = true
|
||||||
|
if posTextDone == -1 {
|
||||||
|
posTextDone = i
|
||||||
|
}
|
||||||
|
textDone = data.Get("text").String()
|
||||||
|
case "response.content_part.done":
|
||||||
|
if posPartDone == -1 {
|
||||||
|
posPartDone = i
|
||||||
|
}
|
||||||
|
case "response.output_item.done":
|
||||||
|
switch data.Get("item.type").String() {
|
||||||
|
case "message":
|
||||||
|
gotMessageDone = true
|
||||||
|
if posMessageDone == -1 {
|
||||||
|
posMessageDone = i
|
||||||
|
}
|
||||||
|
messageText = data.Get("item.content.0.text").String()
|
||||||
|
case "function_call":
|
||||||
|
gotFuncDone = true
|
||||||
|
funcName = data.Get("item.name").String()
|
||||||
|
funcArgs = data.Get("item.arguments").String()
|
||||||
|
}
|
||||||
|
case "response.output_item.added":
|
||||||
|
if data.Get("item.type").String() == "function_call" && posFuncAdded == -1 {
|
||||||
|
posFuncAdded = i
|
||||||
|
}
|
||||||
|
case "response.completed":
|
||||||
|
gotResponseDone = true
|
||||||
|
responseID = data.Get("response.id").String()
|
||||||
|
instructions = data.Get("response.instructions").String()
|
||||||
|
cachedTokens = data.Get("response.usage.input_tokens_details.cached_tokens").Int()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gotTextDone {
|
||||||
|
t.Fatalf("missing response.output_text.done event")
|
||||||
|
}
|
||||||
|
if posTextDone == -1 || posPartDone == -1 || posMessageDone == -1 || posFuncAdded == -1 {
|
||||||
|
t.Fatalf("missing ordering events: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
|
||||||
|
}
|
||||||
|
if !(posTextDone < posPartDone && posPartDone < posMessageDone && posMessageDone < posFuncAdded) {
|
||||||
|
t.Fatalf("unexpected message/function ordering: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
|
||||||
|
}
|
||||||
|
if !gotMessageDone {
|
||||||
|
t.Fatalf("missing message response.output_item.done event")
|
||||||
|
}
|
||||||
|
if !gotFuncDone {
|
||||||
|
t.Fatalf("missing function_call response.output_item.done event")
|
||||||
|
}
|
||||||
|
if !gotResponseDone {
|
||||||
|
t.Fatalf("missing response.completed event")
|
||||||
|
}
|
||||||
|
|
||||||
|
if textDone != "让我先了解" {
|
||||||
|
t.Fatalf("unexpected output_text.done text: got %q", textDone)
|
||||||
|
}
|
||||||
|
if messageText != "让我先了解" {
|
||||||
|
t.Fatalf("unexpected message done text: got %q", messageText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseID != "resp_req_vrtx_1" {
|
||||||
|
t.Fatalf("unexpected response id: got %q", responseID)
|
||||||
|
}
|
||||||
|
if instructions != "test instructions" {
|
||||||
|
t.Fatalf("unexpected instructions echo: got %q", instructions)
|
||||||
|
}
|
||||||
|
if cachedTokens != 2 {
|
||||||
|
t.Fatalf("unexpected cached token count: got %d", cachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if funcName != "mcp__serena__list_dir" {
|
||||||
|
t.Fatalf("unexpected function name: got %q", funcName)
|
||||||
|
}
|
||||||
|
if !gjson.Valid(funcArgs) {
|
||||||
|
t.Fatalf("invalid function arguments JSON: %q", funcArgs)
|
||||||
|
}
|
||||||
|
if gjson.Get(funcArgs, "recursive").Bool() != false {
|
||||||
|
t.Fatalf("unexpected recursive arg: %v", gjson.Get(funcArgs, "recursive").Value())
|
||||||
|
}
|
||||||
|
if gjson.Get(funcArgs, "relative_path").String() != "internal" {
|
||||||
|
t.Fatalf("unexpected relative_path arg: %q", gjson.Get(funcArgs, "relative_path").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertGeminiResponseToOpenAIResponses_ReasoningEncryptedContent(t *testing.T) {
|
||||||
|
sig := "RXE0RENrZ0lDeEFDR0FJcVFOZDdjUzlleGFuRktRdFcvSzNyZ2MvWDNCcDQ4RmxSbGxOWUlOVU5kR1l1UHMrMGdkMVp0Vkg3ekdKU0g4YVljc2JjN3lNK0FrdGpTNUdqamI4T3Z0VVNETzdQd3pmcFhUOGl3U3hXUEJvTVFRQ09mWTFyMEtTWGZxUUlJakFqdmFGWk83RW1XRlBKckJVOVpkYzdDKw=="
|
||||||
|
in := []string{
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"thoughtSignature":"` + sig + `","text":""}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"a"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hello"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out []string
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
addedEnc string
|
||||||
|
doneEnc string
|
||||||
|
)
|
||||||
|
for _, chunk := range out {
|
||||||
|
ev, data := parseSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_item.added":
|
||||||
|
if data.Get("item.type").String() == "reasoning" {
|
||||||
|
addedEnc = data.Get("item.encrypted_content").String()
|
||||||
|
}
|
||||||
|
case "response.output_item.done":
|
||||||
|
if data.Get("item.type").String() == "reasoning" {
|
||||||
|
doneEnc = data.Get("item.encrypted_content").String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if addedEnc != sig {
|
||||||
|
t.Fatalf("unexpected encrypted_content in response.output_item.added: got %q", addedEnc)
|
||||||
|
}
|
||||||
|
if doneEnc != sig {
|
||||||
|
t.Fatalf("unexpected encrypted_content in response.output_item.done: got %q", doneEnc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertGeminiResponseToOpenAIResponses_FunctionCallEventOrder(t *testing.T) {
|
||||||
|
in := []string{
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool1"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool2","args":{"a":1}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out []string
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
posAdded := []int{-1, -1, -1}
|
||||||
|
posArgsDelta := []int{-1, -1, -1}
|
||||||
|
posArgsDone := []int{-1, -1, -1}
|
||||||
|
posItemDone := []int{-1, -1, -1}
|
||||||
|
posCompleted := -1
|
||||||
|
deltaByIndex := map[int]string{}
|
||||||
|
|
||||||
|
for i, chunk := range out {
|
||||||
|
ev, data := parseSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_item.added":
|
||||||
|
if data.Get("item.type").String() != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idx := int(data.Get("output_index").Int())
|
||||||
|
if idx >= 0 && idx < len(posAdded) {
|
||||||
|
posAdded[idx] = i
|
||||||
|
}
|
||||||
|
case "response.function_call_arguments.delta":
|
||||||
|
idx := int(data.Get("output_index").Int())
|
||||||
|
if idx >= 0 && idx < len(posArgsDelta) {
|
||||||
|
posArgsDelta[idx] = i
|
||||||
|
deltaByIndex[idx] = data.Get("delta").String()
|
||||||
|
}
|
||||||
|
case "response.function_call_arguments.done":
|
||||||
|
idx := int(data.Get("output_index").Int())
|
||||||
|
if idx >= 0 && idx < len(posArgsDone) {
|
||||||
|
posArgsDone[idx] = i
|
||||||
|
}
|
||||||
|
case "response.output_item.done":
|
||||||
|
if data.Get("item.type").String() != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idx := int(data.Get("output_index").Int())
|
||||||
|
if idx >= 0 && idx < len(posItemDone) {
|
||||||
|
posItemDone[idx] = i
|
||||||
|
}
|
||||||
|
case "response.completed":
|
||||||
|
posCompleted = i
|
||||||
|
|
||||||
|
output := data.Get("response.output")
|
||||||
|
if !output.Exists() || !output.IsArray() {
|
||||||
|
t.Fatalf("missing response.output in response.completed")
|
||||||
|
}
|
||||||
|
if len(output.Array()) != 3 {
|
||||||
|
t.Fatalf("unexpected response.output length: got %d", len(output.Array()))
|
||||||
|
}
|
||||||
|
if data.Get("response.output.0.name").String() != "tool0" || data.Get("response.output.0.arguments").String() != "{}" {
|
||||||
|
t.Fatalf("unexpected output[0]: %s", data.Get("response.output.0").Raw)
|
||||||
|
}
|
||||||
|
if data.Get("response.output.1.name").String() != "tool1" || data.Get("response.output.1.arguments").String() != "{}" {
|
||||||
|
t.Fatalf("unexpected output[1]: %s", data.Get("response.output.1").Raw)
|
||||||
|
}
|
||||||
|
if data.Get("response.output.2.name").String() != "tool2" {
|
||||||
|
t.Fatalf("unexpected output[2] name: %s", data.Get("response.output.2").Raw)
|
||||||
|
}
|
||||||
|
if !gjson.Valid(data.Get("response.output.2.arguments").String()) {
|
||||||
|
t.Fatalf("unexpected output[2] arguments: %q", data.Get("response.output.2.arguments").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if posCompleted == -1 {
|
||||||
|
t.Fatalf("missing response.completed event")
|
||||||
|
}
|
||||||
|
for idx := 0; idx < 3; idx++ {
|
||||||
|
if posAdded[idx] == -1 || posArgsDelta[idx] == -1 || posArgsDone[idx] == -1 || posItemDone[idx] == -1 {
|
||||||
|
t.Fatalf("missing function call events for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
|
||||||
|
}
|
||||||
|
if !(posAdded[idx] < posArgsDelta[idx] && posArgsDelta[idx] < posArgsDone[idx] && posArgsDone[idx] < posItemDone[idx]) {
|
||||||
|
t.Fatalf("unexpected ordering for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
|
||||||
|
}
|
||||||
|
if idx > 0 && !(posItemDone[idx-1] < posAdded[idx]) {
|
||||||
|
t.Fatalf("function call events overlap between %d and %d: prevDone=%d nextAdded=%d", idx-1, idx, posItemDone[idx-1], posAdded[idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if deltaByIndex[0] != "{}" {
|
||||||
|
t.Fatalf("unexpected delta for output_index 0: got %q", deltaByIndex[0])
|
||||||
|
}
|
||||||
|
if deltaByIndex[1] != "{}" {
|
||||||
|
t.Fatalf("unexpected delta for output_index 1: got %q", deltaByIndex[1])
|
||||||
|
}
|
||||||
|
if deltaByIndex[2] == "" || !gjson.Valid(deltaByIndex[2]) || gjson.Get(deltaByIndex[2], "a").Int() != 1 {
|
||||||
|
t.Fatalf("unexpected delta for output_index 2: got %q", deltaByIndex[2])
|
||||||
|
}
|
||||||
|
if !(posItemDone[2] < posCompleted) {
|
||||||
|
t.Fatalf("response.completed should be after last output_item.done: last=%d completed=%d", posItemDone[2], posCompleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering(t *testing.T) {
|
||||||
|
in := []string{
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0","args":{"x":"y"}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hi"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||||
|
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out []string
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
posFuncDone := -1
|
||||||
|
posMsgAdded := -1
|
||||||
|
posCompleted := -1
|
||||||
|
|
||||||
|
for i, chunk := range out {
|
||||||
|
ev, data := parseSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_item.done":
|
||||||
|
if data.Get("item.type").String() == "function_call" && data.Get("output_index").Int() == 0 {
|
||||||
|
posFuncDone = i
|
||||||
|
}
|
||||||
|
case "response.output_item.added":
|
||||||
|
if data.Get("item.type").String() == "message" && data.Get("output_index").Int() == 1 {
|
||||||
|
posMsgAdded = i
|
||||||
|
}
|
||||||
|
case "response.completed":
|
||||||
|
posCompleted = i
|
||||||
|
if data.Get("response.output.0.type").String() != "function_call" {
|
||||||
|
t.Fatalf("expected response.output[0] to be function_call: %s", data.Get("response.output.0").Raw)
|
||||||
|
}
|
||||||
|
if data.Get("response.output.1.type").String() != "message" {
|
||||||
|
t.Fatalf("expected response.output[1] to be message: %s", data.Get("response.output.1").Raw)
|
||||||
|
}
|
||||||
|
if data.Get("response.output.1.content.0.text").String() != "hi" {
|
||||||
|
t.Fatalf("unexpected message text in response.output[1]: %s", data.Get("response.output.1").Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if posFuncDone == -1 || posMsgAdded == -1 || posCompleted == -1 {
|
||||||
|
t.Fatalf("missing required events: funcDone=%d msgAdded=%d completed=%d", posFuncDone, posMsgAdded, posCompleted)
|
||||||
|
}
|
||||||
|
if !(posFuncDone < posMsgAdded) {
|
||||||
|
t.Fatalf("expected function_call to complete before message is added: funcDone=%d msgAdded=%d", posFuncDone, posMsgAdded)
|
||||||
|
}
|
||||||
|
if !(posMsgAdded < posCompleted) {
|
||||||
|
t.Fatalf("expected response.completed after message added: msgAdded=%d completed=%d", posMsgAdded, posCompleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -499,6 +499,16 @@ func shortenToolNameIfNeeded(name string) string {
|
|||||||
return name[:limit]
|
return name[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureKiroInputSchema(parameters interface{}) interface{} {
|
||||||
|
if parameters != nil {
|
||||||
|
return parameters
|
||||||
|
}
|
||||||
|
return map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||||
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||||
var kiroTools []KiroToolWrapper
|
var kiroTools []KiroToolWrapper
|
||||||
@@ -509,7 +519,12 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
|||||||
for _, tool := range tools.Array() {
|
for _, tool := range tools.Array() {
|
||||||
name := tool.Get("name").String()
|
name := tool.Get("name").String()
|
||||||
description := tool.Get("description").String()
|
description := tool.Get("description").String()
|
||||||
inputSchema := tool.Get("input_schema").Value()
|
inputSchemaResult := tool.Get("input_schema")
|
||||||
|
var inputSchema interface{}
|
||||||
|
if inputSchemaResult.Exists() && inputSchemaResult.Type != gjson.Null {
|
||||||
|
inputSchema = inputSchemaResult.Value()
|
||||||
|
}
|
||||||
|
inputSchema = ensureKiroInputSchema(inputSchema)
|
||||||
|
|
||||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||||
originalName := name
|
originalName := name
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWr
|
|||||||
|
|
||||||
name := kirocommon.GetString(fn, "name")
|
name := kirocommon.GetString(fn, "name")
|
||||||
description := kirocommon.GetString(fn, "description")
|
description := kirocommon.GetString(fn, "description")
|
||||||
parameters := fn["parameters"]
|
parameters := ensureKiroInputSchema(fn["parameters"])
|
||||||
|
|
||||||
if name == "" {
|
if name == "" {
|
||||||
continue
|
continue
|
||||||
@@ -368,4 +368,4 @@ func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]i
|
|||||||
// LogStreamEvent logs a streaming event for debugging
|
// LogStreamEvent logs a streaming event for debugging
|
||||||
func LogStreamEvent(eventType, data string) {
|
func LogStreamEvent(eventType, data string) {
|
||||||
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -381,6 +381,16 @@ func shortenToolNameIfNeeded(name string) string {
|
|||||||
return name[:limit]
|
return name[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureKiroInputSchema(parameters interface{}) interface{} {
|
||||||
|
if parameters != nil {
|
||||||
|
return parameters
|
||||||
|
}
|
||||||
|
return map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
||||||
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||||
var kiroTools []KiroToolWrapper
|
var kiroTools []KiroToolWrapper
|
||||||
@@ -401,7 +411,12 @@ func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
|||||||
|
|
||||||
name := fn.Get("name").String()
|
name := fn.Get("name").String()
|
||||||
description := fn.Get("description").String()
|
description := fn.Get("description").String()
|
||||||
parameters := fn.Get("parameters").Value()
|
parametersResult := fn.Get("parameters")
|
||||||
|
var parameters interface{}
|
||||||
|
if parametersResult.Exists() && parametersResult.Type != gjson.Null {
|
||||||
|
parameters = parametersResult.Value()
|
||||||
|
}
|
||||||
|
parameters = ensureKiroInputSchema(parameters)
|
||||||
|
|
||||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||||
originalName := name
|
originalName := name
|
||||||
|
|||||||
@@ -88,13 +88,15 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
|||||||
var messagesJSON = "[]"
|
var messagesJSON = "[]"
|
||||||
|
|
||||||
// Handle system message first
|
// Handle system message first
|
||||||
systemMsgJSON := `{"role":"system","content":[{"type":"text","text":"Use ANY tool, the parameters MUST accord with RFC 8259 (The JavaScript Object Notation (JSON) Data Interchange Format), the keys and value MUST be enclosed in double quotes."}]}`
|
systemMsgJSON := `{"role":"system","content":[]}`
|
||||||
|
hasSystemContent := false
|
||||||
if system := root.Get("system"); system.Exists() {
|
if system := root.Get("system"); system.Exists() {
|
||||||
if system.Type == gjson.String {
|
if system.Type == gjson.String {
|
||||||
if system.String() != "" {
|
if system.String() != "" {
|
||||||
oldSystem := `{"type":"text","text":""}`
|
oldSystem := `{"type":"text","text":""}`
|
||||||
oldSystem, _ = sjson.Set(oldSystem, "text", system.String())
|
oldSystem, _ = sjson.Set(oldSystem, "text", system.String())
|
||||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", oldSystem)
|
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", oldSystem)
|
||||||
|
hasSystemContent = true
|
||||||
}
|
}
|
||||||
} else if system.Type == gjson.JSON {
|
} else if system.Type == gjson.JSON {
|
||||||
if system.IsArray() {
|
if system.IsArray() {
|
||||||
@@ -102,12 +104,16 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
|||||||
for i := 0; i < len(systemResults); i++ {
|
for i := 0; i < len(systemResults); i++ {
|
||||||
if contentItem, ok := convertClaudeContentPart(systemResults[i]); ok {
|
if contentItem, ok := convertClaudeContentPart(systemResults[i]); ok {
|
||||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
|
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
|
||||||
|
hasSystemContent = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
|
// Only add system message if it has content
|
||||||
|
if hasSystemContent {
|
||||||
|
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
|
||||||
|
}
|
||||||
|
|
||||||
// Process Anthropic messages
|
// Process Anthropic messages
|
||||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||||
|
|||||||
@@ -181,11 +181,11 @@ func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
|
|||||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||||
resultJSON := gjson.ParseBytes(result)
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
|
||||||
// Find the relevant message (skip system message at index 0)
|
// Find the relevant message
|
||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
if len(messages) < 2 {
|
if len(messages) < 1 {
|
||||||
if tt.wantHasReasoningContent || tt.wantHasContent {
|
if tt.wantHasReasoningContent || tt.wantHasContent {
|
||||||
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
|
t.Fatalf("Expected at least 1 message, got %d", len(messages))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -272,15 +272,15 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
|
|||||||
|
|
||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
|
// Should have: user + assistant (thinking-only) + user = 3 messages
|
||||||
if len(messages) != 4 {
|
if len(messages) != 3 {
|
||||||
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 3 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the assistant message (index 2) has reasoning_content
|
// Check the assistant message (index 1) has reasoning_content
|
||||||
assistantMsg := messages[2]
|
assistantMsg := messages[1]
|
||||||
if assistantMsg.Get("role").String() != "assistant" {
|
if assistantMsg.Get("role").String() != "assistant" {
|
||||||
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
|
t.Errorf("Expected message[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !assistantMsg.Get("reasoning_content").Exists() {
|
if !assistantMsg.Get("reasoning_content").Exists() {
|
||||||
@@ -292,6 +292,104 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToOpenAI_SystemMessageScenarios(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputJSON string
|
||||||
|
wantHasSys bool
|
||||||
|
wantSysText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty string system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": "",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "String system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": "Be helpful",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: true,
|
||||||
|
wantSysText: "Be helpful",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array system field with text",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": [{"type": "text", "text": "Array system"}],
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: true,
|
||||||
|
wantSysText: "Array system",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array system field with multiple text blocks",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": [
|
||||||
|
{"type": "text", "text": "Block 1"},
|
||||||
|
{"type": "text", "text": "Block 2"}
|
||||||
|
],
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: true,
|
||||||
|
wantSysText: "Block 2", // We will update the test logic to check all blocks or specifically the second one
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
|
hasSys := false
|
||||||
|
var sysMsg gjson.Result
|
||||||
|
if len(messages) > 0 && messages[0].Get("role").String() == "system" {
|
||||||
|
hasSys = true
|
||||||
|
sysMsg = messages[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasSys != tt.wantHasSys {
|
||||||
|
t.Errorf("got hasSystem = %v, want %v", hasSys, tt.wantHasSys)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantHasSys {
|
||||||
|
// Check content - it could be string or array in OpenAI
|
||||||
|
content := sysMsg.Get("content")
|
||||||
|
var gotText string
|
||||||
|
if content.IsArray() {
|
||||||
|
arr := content.Array()
|
||||||
|
if len(arr) > 0 {
|
||||||
|
// Get the last element's text for validation
|
||||||
|
gotText = arr[len(arr)-1].Get("text").String()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gotText = content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantSysText != "" && gotText != tt.wantSysText {
|
||||||
|
t.Errorf("got system text = %q, want %q", gotText, tt.wantSysText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
||||||
inputJSON := `{
|
inputJSON := `{
|
||||||
"model": "claude-3-opus",
|
"model": "claude-3-opus",
|
||||||
@@ -318,39 +416,35 @@ func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
|||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
||||||
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
|
// Correct order: assistant(tool_calls) + tool(result) + user(before+after)
|
||||||
if len(messages) != 4 {
|
if len(messages) != 3 {
|
||||||
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[0].Get("role").String() != "system" {
|
if messages[0].Get("role").String() != "assistant" || !messages[0].Get("tool_calls").Exists() {
|
||||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
t.Fatalf("Expected messages[0] to be assistant tool_calls, got %s: %s", messages[0].Get("role").String(), messages[0].Raw)
|
||||||
}
|
|
||||||
|
|
||||||
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
|
|
||||||
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
||||||
if messages[2].Get("role").String() != "tool" {
|
if messages[1].Get("role").String() != "tool" {
|
||||||
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
|
t.Fatalf("Expected messages[1] to be tool (must follow tool_calls), got %s", messages[1].Get("role").String())
|
||||||
}
|
}
|
||||||
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
|
if got := messages[1].Get("tool_call_id").String(); got != "call_1" {
|
||||||
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
||||||
}
|
}
|
||||||
if got := messages[2].Get("content").String(); got != "tool ok" {
|
if got := messages[1].Get("content").String(); got != "tool ok" {
|
||||||
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// User message comes after tool message
|
// User message comes after tool message
|
||||||
if messages[3].Get("role").String() != "user" {
|
if messages[2].Get("role").String() != "user" {
|
||||||
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
|
t.Fatalf("Expected messages[2] to be user, got %s", messages[2].Get("role").String())
|
||||||
}
|
}
|
||||||
// User message should contain both "before" and "after" text
|
// User message should contain both "before" and "after" text
|
||||||
if got := messages[3].Get("content.0.text").String(); got != "before" {
|
if got := messages[2].Get("content.0.text").String(); got != "before" {
|
||||||
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
||||||
}
|
}
|
||||||
if got := messages[3].Get("content.1.text").String(); got != "after" {
|
if got := messages[2].Get("content.1.text").String(); got != "after" {
|
||||||
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -378,16 +472,16 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
|||||||
resultJSON := gjson.ParseBytes(result)
|
resultJSON := gjson.ParseBytes(result)
|
||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// system + assistant(tool_calls) + tool(result)
|
// assistant(tool_calls) + tool(result)
|
||||||
if len(messages) != 3 {
|
if len(messages) != 2 {
|
||||||
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[2].Get("role").String() != "tool" {
|
if messages[1].Get("role").String() != "tool" {
|
||||||
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
|
t.Fatalf("Expected messages[1] to be tool, got %s", messages[1].Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
toolContent := messages[2].Get("content").String()
|
toolContent := messages[1].Get("content").String()
|
||||||
parsed := gjson.Parse(toolContent)
|
parsed := gjson.Parse(toolContent)
|
||||||
if parsed.Get("foo").String() != "bar" {
|
if parsed.Get("foo").String() != "bar" {
|
||||||
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
||||||
@@ -414,18 +508,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T
|
|||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// New behavior: content + tool_calls unified in single assistant message
|
// New behavior: content + tool_calls unified in single assistant message
|
||||||
// Expect: system + assistant(content[pre,post] + tool_calls)
|
// Expect: assistant(content[pre,post] + tool_calls)
|
||||||
if len(messages) != 2 {
|
if len(messages) != 1 {
|
||||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[0].Get("role").String() != "system" {
|
assistantMsg := messages[0]
|
||||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
|
||||||
}
|
|
||||||
|
|
||||||
assistantMsg := messages[1]
|
|
||||||
if assistantMsg.Get("role").String() != "assistant" {
|
if assistantMsg.Get("role").String() != "assistant" {
|
||||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have both content and tool_calls in same message
|
// Should have both content and tool_calls in same message
|
||||||
@@ -470,14 +560,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
|
|||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
||||||
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
// Expect: assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
||||||
if len(messages) != 2 {
|
if len(messages) != 1 {
|
||||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
assistantMsg := messages[1]
|
assistantMsg := messages[0]
|
||||||
if assistantMsg.Get("role").String() != "assistant" {
|
if assistantMsg.Get("role").String() != "assistant" {
|
||||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have content with both pre and post
|
// Should have content with both pre and post
|
||||||
|
|||||||
@@ -289,21 +289,17 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
|||||||
// Only process if usage has actual values (not null)
|
// Only process if usage has actual values (not null)
|
||||||
if param.FinishReason != "" {
|
if param.FinishReason != "" {
|
||||||
usage := root.Get("usage")
|
usage := root.Get("usage")
|
||||||
var inputTokens, outputTokens int64
|
var inputTokens, outputTokens, cachedTokens int64
|
||||||
if usage.Exists() && usage.Type != gjson.Null {
|
if usage.Exists() && usage.Type != gjson.Null {
|
||||||
// Check if usage has actual token counts
|
inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage)
|
||||||
promptTokens := usage.Get("prompt_tokens")
|
|
||||||
completionTokens := usage.Get("completion_tokens")
|
|
||||||
|
|
||||||
if promptTokens.Exists() && completionTokens.Exists() {
|
|
||||||
inputTokens = promptTokens.Int()
|
|
||||||
outputTokens = completionTokens.Int()
|
|
||||||
}
|
|
||||||
// Send message_delta with usage
|
// Send message_delta with usage
|
||||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||||
param.MessageDeltaSent = true
|
param.MessageDeltaSent = true
|
||||||
|
|
||||||
@@ -423,13 +419,12 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
|||||||
|
|
||||||
// Set usage information
|
// Set usage information
|
||||||
if usage := root.Get("usage"); usage.Exists() {
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", usage.Get("prompt_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(usage)
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", usage.Get("completion_tokens").Int())
|
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||||
reasoningTokens := int64(0)
|
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||||
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
|
if cachedTokens > 0 {
|
||||||
reasoningTokens = v.Int()
|
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
}
|
}
|
||||||
out, _ = sjson.Set(out, "usage.reasoning_tokens", reasoningTokens)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{out}
|
return []string{out}
|
||||||
@@ -674,8 +669,12 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
|
|
||||||
if respUsage := root.Get("usage"); respUsage.Exists() {
|
if respUsage := root.Get("usage"); respUsage.Exists() {
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", respUsage.Get("prompt_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(respUsage)
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", respUsage.Get("completion_tokens").Int())
|
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stopReasonSet {
|
if !stopReasonSet {
|
||||||
@@ -692,3 +691,23 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractOpenAIUsage(usage gjson.Result) (int64, int64, int64) {
|
||||||
|
if !usage.Exists() || usage.Type == gjson.Null {
|
||||||
|
return 0, 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := usage.Get("prompt_tokens").Int()
|
||||||
|
outputTokens := usage.Get("completion_tokens").Int()
|
||||||
|
cachedTokens := usage.Get("prompt_tokens_details.cached_tokens").Int()
|
||||||
|
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
if inputTokens >= cachedTokens {
|
||||||
|
inputTokens -= cachedTokens
|
||||||
|
} else {
|
||||||
|
inputTokens = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputTokens, outputTokens, cachedTokens
|
||||||
|
}
|
||||||
|
|||||||
@@ -77,7 +77,13 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Candidate count (OpenAI 'n' parameter)
|
||||||
|
if candidateCount := genConfig.Get("candidateCount"); candidateCount.Exists() {
|
||||||
|
out, _ = sjson.Set(out, "n", candidateCount.Int())
|
||||||
|
}
|
||||||
|
|
||||||
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
|
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
|
||||||
|
// Always perform conversion to support allowCompat models that may not be in registry
|
||||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||||
|
|||||||
@@ -12,10 +12,23 @@ import (
|
|||||||
|
|
||||||
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||||
|
|
||||||
|
const placeholderReasonDescription = "Brief explanation of why you are calling this tool"
|
||||||
|
|
||||||
// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
|
// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
|
||||||
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||||
// semantic information as description hints.
|
// semantic information as description hints.
|
||||||
func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
||||||
|
return cleanJSONSchema(jsonStr, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini tool calling.
|
||||||
|
// It removes unsupported keywords and simplifies schemas, without adding empty-schema placeholders.
|
||||||
|
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||||
|
return cleanJSONSchema(jsonStr, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanJSONSchema performs the core cleaning operations on the JSON schema.
|
||||||
|
func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
|
||||||
// Phase 1: Convert and add hints
|
// Phase 1: Convert and add hints
|
||||||
jsonStr = convertRefsToHints(jsonStr)
|
jsonStr = convertRefsToHints(jsonStr)
|
||||||
jsonStr = convertConstToEnum(jsonStr)
|
jsonStr = convertConstToEnum(jsonStr)
|
||||||
@@ -31,10 +44,94 @@ func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
|||||||
|
|
||||||
// Phase 3: Cleanup
|
// Phase 3: Cleanup
|
||||||
jsonStr = removeUnsupportedKeywords(jsonStr)
|
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||||
|
if !addPlaceholder {
|
||||||
|
// Gemini schema cleanup: remove nullable/title and placeholder-only fields.
|
||||||
|
jsonStr = removeKeywords(jsonStr, []string{"nullable", "title"})
|
||||||
|
jsonStr = removePlaceholderFields(jsonStr)
|
||||||
|
}
|
||||||
jsonStr = cleanupRequiredFields(jsonStr)
|
jsonStr = cleanupRequiredFields(jsonStr)
|
||||||
|
|
||||||
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
|
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
|
||||||
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
if addPlaceholder {
|
||||||
|
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
|
||||||
|
func removeKeywords(jsonStr string, keywords []string) string {
|
||||||
|
for _, key := range keywords {
|
||||||
|
for _, p := range findPaths(jsonStr, key) {
|
||||||
|
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries.
|
||||||
|
func removePlaceholderFields(jsonStr string) string {
|
||||||
|
// Remove "_" placeholder properties.
|
||||||
|
paths := findPaths(jsonStr, "_")
|
||||||
|
sortByDepth(paths)
|
||||||
|
for _, p := range paths {
|
||||||
|
if !strings.HasSuffix(p, ".properties._") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||||
|
parentPath := trimSuffix(p, ".properties._")
|
||||||
|
reqPath := joinPath(parentPath, "required")
|
||||||
|
req := gjson.Get(jsonStr, reqPath)
|
||||||
|
if req.IsArray() {
|
||||||
|
var filtered []string
|
||||||
|
for _, r := range req.Array() {
|
||||||
|
if r.String() != "_" {
|
||||||
|
filtered = append(filtered, r.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||||
|
} else {
|
||||||
|
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove placeholder-only "reason" objects.
|
||||||
|
reasonPaths := findPaths(jsonStr, "reason")
|
||||||
|
sortByDepth(reasonPaths)
|
||||||
|
for _, p := range reasonPaths {
|
||||||
|
if !strings.HasSuffix(p, ".properties.reason") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parentPath := trimSuffix(p, ".properties.reason")
|
||||||
|
props := gjson.Get(jsonStr, joinPath(parentPath, "properties"))
|
||||||
|
if !props.IsObject() || len(props.Map()) != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
desc := gjson.Get(jsonStr, p+".description").String()
|
||||||
|
if desc != placeholderReasonDescription {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||||
|
reqPath := joinPath(parentPath, "required")
|
||||||
|
req := gjson.Get(jsonStr, reqPath)
|
||||||
|
if req.IsArray() {
|
||||||
|
var filtered []string
|
||||||
|
for _, r := range req.Array() {
|
||||||
|
if r.String() != "reason" {
|
||||||
|
filtered = append(filtered, r.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||||
|
} else {
|
||||||
|
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return jsonStr
|
return jsonStr
|
||||||
}
|
}
|
||||||
@@ -409,7 +506,7 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
|||||||
// Add placeholder "reason" property
|
// Add placeholder "reason" property
|
||||||
reasonPath := joinPath(propsPath, "reason")
|
reasonPath := joinPath(propsPath, "reason")
|
||||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
|
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
|
||||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription)
|
||||||
|
|
||||||
// Add to required array
|
// Add to required array
|
||||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
||||||
|
|||||||
@@ -170,7 +170,9 @@ func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenData, err := kiroauth.LoadKiroIDEToken()
|
// Use retry logic to handle file lock contention (e.g., Kiro IDE writing the file)
|
||||||
|
// This prevents "being used by another process" errors on Windows
|
||||||
|
tokenData, err := kiroauth.LoadKiroIDETokenWithRetry(10, 50*time.Millisecond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to load Kiro IDE token after change: %v", err)
|
log.Debugf("failed to load Kiro IDE token after change: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -167,6 +167,16 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
|||||||
"virtual_parent_id": primary.ID,
|
"virtual_parent_id": primary.ID,
|
||||||
"type": metadata["type"],
|
"type": metadata["type"],
|
||||||
}
|
}
|
||||||
|
if v, ok := metadata["disable_cooling"]; ok {
|
||||||
|
metadataCopy["disable_cooling"] = v
|
||||||
|
} else if v, ok := metadata["disable-cooling"]; ok {
|
||||||
|
metadataCopy["disable_cooling"] = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["request_retry"]; ok {
|
||||||
|
metadataCopy["request_retry"] = v
|
||||||
|
} else if v, ok := metadata["request-retry"]; ok {
|
||||||
|
metadataCopy["request_retry"] = v
|
||||||
|
}
|
||||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||||
if proxy != "" {
|
if proxy != "" {
|
||||||
metadataCopy["proxy_url"] = proxy
|
metadataCopy["proxy_url"] = proxy
|
||||||
|
|||||||
@@ -69,10 +69,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
|||||||
|
|
||||||
// Create a valid auth file
|
// Create a valid auth file
|
||||||
authData := map[string]any{
|
authData := map[string]any{
|
||||||
"type": "claude",
|
"type": "claude",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"proxy_url": "http://proxy.local",
|
"proxy_url": "http://proxy.local",
|
||||||
"prefix": "test-prefix",
|
"prefix": "test-prefix",
|
||||||
|
"disable_cooling": true,
|
||||||
|
"request_retry": 2,
|
||||||
}
|
}
|
||||||
data, _ := json.Marshal(authData)
|
data, _ := json.Marshal(authData)
|
||||||
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
||||||
@@ -108,6 +110,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
|||||||
if auths[0].ProxyURL != "http://proxy.local" {
|
if auths[0].ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||||
}
|
}
|
||||||
|
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
|
||||||
|
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
|
||||||
|
}
|
||||||
|
if v, ok := auths[0].Metadata["request_retry"].(float64); !ok || int(v) != 2 {
|
||||||
|
t.Errorf("expected request_retry 2, got %v", auths[0].Metadata["request_retry"])
|
||||||
|
}
|
||||||
if auths[0].Status != coreauth.StatusActive {
|
if auths[0].Status != coreauth.StatusActive {
|
||||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||||
}
|
}
|
||||||
@@ -336,9 +344,11 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"project_id": "project-a, project-b, project-c",
|
"project_id": "project-a, project-b, project-c",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"type": "gemini",
|
"type": "gemini",
|
||||||
|
"request_retry": 2,
|
||||||
|
"disable_cooling": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||||
@@ -376,6 +386,12 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
|||||||
if v.ProxyURL != "http://proxy.local" {
|
if v.ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
||||||
}
|
}
|
||||||
|
if vv, ok := v.Metadata["disable_cooling"].(bool); !ok || !vv {
|
||||||
|
t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"])
|
||||||
|
}
|
||||||
|
if vv, ok := v.Metadata["request_retry"].(int); !ok || vv != 2 {
|
||||||
|
t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"])
|
||||||
|
}
|
||||||
if v.Attributes["runtime_only"] != "true" {
|
if v.Attributes["runtime_only"] != "true" {
|
||||||
t.Error("expected runtime_only=true")
|
t.Error("expected runtime_only=true")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -145,3 +145,111 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
w.clientsMutex.RUnlock()
|
w.clientsMutex.RUnlock()
|
||||||
return snapshotCoreAuths(cfg, w.authDir)
|
return snapshotCoreAuths(cfg, w.authDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NotifyTokenRefreshed 处理后台刷新器的 token 更新通知
|
||||||
|
// 当后台刷新器成功刷新 token 后调用此方法,更新内存中的 Auth 对象
|
||||||
|
// tokenID: token 文件名(如 kiro-xxx.json)
|
||||||
|
// accessToken: 新的 access token
|
||||||
|
// refreshToken: 新的 refresh token
|
||||||
|
// expiresAt: 新的过期时间
|
||||||
|
func (w *Watcher) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
defer w.clientsMutex.Unlock()
|
||||||
|
|
||||||
|
// 遍历 currentAuths,找到匹配的 Auth 并更新
|
||||||
|
updated := false
|
||||||
|
for id, auth := range w.currentAuths {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是 kiro 类型的 auth
|
||||||
|
authType, _ := auth.Metadata["type"].(string)
|
||||||
|
if authType != "kiro" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 多种匹配方式,解决不同来源的 auth 对象字段差异
|
||||||
|
matched := false
|
||||||
|
|
||||||
|
// 1. 通过 auth.ID 匹配(ID 可能包含文件名)
|
||||||
|
if !matched && auth.ID != "" {
|
||||||
|
if auth.ID == tokenID || strings.HasSuffix(auth.ID, "/"+tokenID) || strings.HasSuffix(auth.ID, "\\"+tokenID) {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
// ID 可能是 "kiro-xxx" 格式(无扩展名),tokenID 是 "kiro-xxx.json"
|
||||||
|
if !matched && strings.TrimSuffix(tokenID, ".json") == auth.ID {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 通过 auth.Attributes["path"] 匹配
|
||||||
|
if !matched && auth.Attributes != nil {
|
||||||
|
if authPath := auth.Attributes["path"]; authPath != "" {
|
||||||
|
// 提取文件名部分进行比较
|
||||||
|
pathBase := authPath
|
||||||
|
if idx := strings.LastIndexAny(authPath, "/\\"); idx >= 0 {
|
||||||
|
pathBase = authPath[idx+1:]
|
||||||
|
}
|
||||||
|
if pathBase == tokenID || strings.TrimSuffix(pathBase, ".json") == strings.TrimSuffix(tokenID, ".json") {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 通过 auth.FileName 匹配(原有逻辑)
|
||||||
|
if !matched && auth.FileName != "" {
|
||||||
|
if auth.FileName == tokenID || strings.HasSuffix(auth.FileName, "/"+tokenID) || strings.HasSuffix(auth.FileName, "\\"+tokenID) {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matched {
|
||||||
|
// 更新内存中的 token
|
||||||
|
auth.Metadata["access_token"] = accessToken
|
||||||
|
auth.Metadata["refresh_token"] = refreshToken
|
||||||
|
auth.Metadata["expires_at"] = expiresAt
|
||||||
|
auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||||
|
auth.UpdatedAt = time.Now()
|
||||||
|
auth.LastRefreshedAt = time.Now()
|
||||||
|
|
||||||
|
log.Infof("watcher: updated in-memory auth for token %s (auth ID: %s)", tokenID, id)
|
||||||
|
updated = true
|
||||||
|
|
||||||
|
// 同时更新 runtimeAuths 中的副本(如果存在)
|
||||||
|
if w.runtimeAuths != nil {
|
||||||
|
if runtimeAuth, ok := w.runtimeAuths[id]; ok && runtimeAuth != nil {
|
||||||
|
if runtimeAuth.Metadata == nil {
|
||||||
|
runtimeAuth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
runtimeAuth.Metadata["access_token"] = accessToken
|
||||||
|
runtimeAuth.Metadata["refresh_token"] = refreshToken
|
||||||
|
runtimeAuth.Metadata["expires_at"] = expiresAt
|
||||||
|
runtimeAuth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||||
|
runtimeAuth.UpdatedAt = time.Now()
|
||||||
|
runtimeAuth.LastRefreshedAt = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送更新通知到 authQueue
|
||||||
|
if w.authQueue != nil {
|
||||||
|
go func(authClone *coreauth.Auth) {
|
||||||
|
update := AuthUpdate{
|
||||||
|
Action: AuthUpdateActionModify,
|
||||||
|
ID: authClone.ID,
|
||||||
|
Auth: authClone,
|
||||||
|
}
|
||||||
|
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||||
|
}(auth.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !updated {
|
||||||
|
log.Debugf("watcher: no matching auth found for token %s, will be picked up on next file scan", tokenID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -128,8 +128,23 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
|||||||
// Parameters:
|
// Parameters:
|
||||||
// - c: The Gin context for the request.
|
// - c: The Gin context for the request.
|
||||||
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
|
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
|
||||||
|
models := h.Models()
|
||||||
|
firstID := ""
|
||||||
|
lastID := ""
|
||||||
|
if len(models) > 0 {
|
||||||
|
if id, ok := models[0]["id"].(string); ok {
|
||||||
|
firstID = id
|
||||||
|
}
|
||||||
|
if id, ok := models[len(models)-1]["id"].(string); ok {
|
||||||
|
lastID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": h.Models(),
|
"data": models,
|
||||||
|
"has_more": false,
|
||||||
|
"first_id": firstID,
|
||||||
|
"last_id": lastID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,8 +60,12 @@ func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
|
|||||||
if !strings.HasPrefix(name, "models/") {
|
if !strings.HasPrefix(name, "models/") {
|
||||||
normalizedModel["name"] = "models/" + name
|
normalizedModel["name"] = "models/" + name
|
||||||
}
|
}
|
||||||
normalizedModel["displayName"] = name
|
if displayName, _ := normalizedModel["displayName"].(string); displayName == "" {
|
||||||
normalizedModel["description"] = name
|
normalizedModel["displayName"] = name
|
||||||
|
}
|
||||||
|
if description, _ := normalizedModel["description"].(string); description == "" {
|
||||||
|
normalizedModel["description"] = name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
|
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
|
||||||
normalizedModel["supportedGenerationMethods"] = defaultMethods
|
normalizedModel["supportedGenerationMethods"] = defaultMethods
|
||||||
|
|||||||
@@ -386,6 +386,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
return nil, errMsg
|
return nil, errMsg
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
Payload: cloneBytes(rawJSON),
|
Payload: cloneBytes(rawJSON),
|
||||||
@@ -424,6 +425,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
return nil, errMsg
|
return nil, errMsg
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
Payload: cloneBytes(rawJSON),
|
Payload: cloneBytes(rawJSON),
|
||||||
@@ -465,6 +467,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return nil, errChan
|
return nil, errChan
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
Payload: cloneBytes(rawJSON),
|
Payload: cloneBytes(rawJSON),
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool)
|
|||||||
if modelName == "" {
|
if modelName == "" {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
info := registry.GetGlobalRegistry().GetModelInfo(modelName)
|
info := registry.GetGlobalRegistry().GetModelInfo(modelName, "")
|
||||||
if info == nil || len(info.SupportedEndpoints) == 0 {
|
if info == nil || len(info.SupportedEndpoints) == 0 {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
@@ -34,4 +34,4 @@ func endpointListContains(items []string, value string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user