mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-29 16:54:41 +00:00
Compare commits
177 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a14d19b92c | ||
|
|
8ae0c05ea6 | ||
|
|
8822f20d17 | ||
|
|
f0e5a5a367 | ||
|
|
f6dfea9357 | ||
|
|
cc8dc7f62c | ||
|
|
a3846ea513 | ||
|
|
8d44be858e | ||
|
|
0e6bb076e9 | ||
|
|
ac135fc7cb | ||
|
|
4e1d09809d | ||
|
|
9e855f8100 | ||
|
|
25680a8259 | ||
|
|
13c93e8cfd | ||
|
|
88aa1b9fd1 | ||
|
|
352cb98ff0 | ||
|
|
ac95e92829 | ||
|
|
8526c2da25 | ||
|
|
68a6cabf8b | ||
|
|
ac0e387da1 | ||
|
|
7fe1d102cb | ||
|
|
5850492a93 | ||
|
|
fdbd4041ca | ||
|
|
ebef1fae2a | ||
|
|
c51851689b | ||
|
|
419bf784ab | ||
|
|
4bbeb92e9a | ||
|
|
b436dad8bc | ||
|
|
6ae15d6c44 | ||
|
|
0468bde0d6 | ||
|
|
1d7329e797 | ||
|
|
48ffc4dee7 | ||
|
|
7ebd8f0c44 | ||
|
|
b680c146c1 | ||
|
|
7d6660d181 | ||
|
|
d8e3d4e2b6 | ||
|
|
d26ad8224d | ||
|
|
5c84d69d42 | ||
|
|
527e4b7f26 | ||
|
|
b48485b42b | ||
|
|
79009bb3d4 | ||
|
|
26fc611f86 | ||
|
|
b43743d4f1 | ||
|
|
179e5434b1 | ||
|
|
9f95b31158 | ||
|
|
5da07eae4c | ||
|
|
835ae178d4 | ||
|
|
c80ab8bf0d | ||
|
|
ce87714ef1 | ||
|
|
0452b869e8 | ||
|
|
d2e5857b82 | ||
|
|
f9b005f21f | ||
|
|
532107b4fa | ||
|
|
c44793789b | ||
|
|
4e99525279 | ||
|
|
7547d1d0b3 | ||
|
|
68934942d0 | ||
|
|
09fec34e1c | ||
|
|
9229708b6c | ||
|
|
914db94e79 | ||
|
|
660bd7eff5 | ||
|
|
b907d21851 | ||
|
|
d6cc976d1f | ||
|
|
8aa2cce8c5 | ||
|
|
bf9b2c49df | ||
|
|
77b42c6165 | ||
|
|
446150a747 | ||
|
|
1cbc4834e1 | ||
|
|
a8a5d03c33 | ||
|
|
76aa917882 | ||
|
|
6ac9b31e4e | ||
|
|
0ad3e8457f | ||
|
|
444a47ae63 | ||
|
|
725f4fdff4 | ||
|
|
c23e46f45d | ||
|
|
b148820c35 | ||
|
|
134f41496d | ||
|
|
c5838dd58d | ||
|
|
b6ca5ef7ce | ||
|
|
1ae994b4aa | ||
|
|
84e9793e61 | ||
|
|
32e64dacfd | ||
|
|
cc1d8f6629 | ||
|
|
5446cd2b02 | ||
|
|
8de0885b7d | ||
|
|
16243f18fd | ||
|
|
a6ce5f36e6 | ||
|
|
e73cf42e28 | ||
|
|
b45343e812 | ||
|
|
8599b1560e | ||
|
|
8bde8c37c0 | ||
|
|
82df5bf88a | ||
|
|
acb1066de8 | ||
|
|
27c68f5bb2 | ||
|
|
68dd2bfe82 | ||
|
|
65a87815e7 | ||
|
|
b80793ca82 | ||
|
|
601550f238 | ||
|
|
41b1cf2273 | ||
|
|
2baf35b3ef | ||
|
|
846e75b893 | ||
|
|
fc0257d6d9 | ||
|
|
f3c164d345 | ||
|
|
4040b1e766 | ||
|
|
3b4f9f43db | ||
|
|
37a09ecb23 | ||
|
|
0da34d3c2d | ||
|
|
74bf7eda8f | ||
|
|
9032042cfa | ||
|
|
030bf5e6c7 | ||
|
|
d3100085b0 | ||
|
|
f481d25133 | ||
|
|
8c6c90da74 | ||
|
|
24bcfd9c03 | ||
|
|
816fb4c5da | ||
|
|
c1bb77c7c9 | ||
|
|
6bcac3a55a | ||
|
|
fc346f4537 | ||
|
|
43e531a3b6 | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
4eeec297de | ||
|
|
77cc4ce3a0 | ||
|
|
37dfea1d3f | ||
|
|
e6626c672a | ||
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
c785c1a3ca | ||
|
|
514ae341c8 | ||
|
|
0659ffab75 | ||
|
|
8ce07f38dd | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
0aaf177640 | ||
|
|
450d1227bd | ||
|
|
492b9c46f0 | ||
|
|
6e634fe3f9 | ||
|
|
4e26182d14 | ||
|
|
8f97a5f77c | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
3b421c8181 | ||
|
|
2a4d3e60f3 | ||
|
|
8b5af2ab84 | ||
|
|
d887716ebd | ||
|
|
5dc1848466 | ||
|
|
9491517b26 | ||
|
|
9370b5bd04 | ||
|
|
abb51a0d93 | ||
|
|
c8d809131b | ||
|
|
dd71c73a9f | ||
|
|
afc8a0f9be | ||
|
|
a99522224f | ||
|
|
f5d46b9ca2 | ||
|
|
d693d7993b | ||
|
|
5936f9895c | ||
|
|
0cbfe7f457 | ||
|
|
b9ae4ab803 | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
2615f489d6 |
@@ -31,6 +31,7 @@ bin/*
|
|||||||
.agent/*
|
.agent/*
|
||||||
.agents/*
|
.agents/*
|
||||||
.opencode/*
|
.opencode/*
|
||||||
|
.idea/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -44,6 +44,7 @@ GEMINI.md
|
|||||||
.agents/*
|
.agents/*
|
||||||
.agents/*
|
.agents/*
|
||||||
.opencode/*
|
.opencode/*
|
||||||
|
.idea/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
|
|||||||
58
README.md
58
README.md
@@ -10,23 +10,59 @@ The Plus release stays in lockstep with the mainline features.
|
|||||||
|
|
||||||
## Differences from the Mainline
|
## Differences from the Mainline
|
||||||
|
|
||||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
|
||||||
|
|
||||||
## New Features (Plus Enhanced)
|
## New Features (Plus Enhanced)
|
||||||
|
|
||||||
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
|
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||||
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
|
|
||||||
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
|
|
||||||
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
|
|
||||||
- **Device Fingerprint**: Device fingerprint generation for enhanced security
|
|
||||||
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
|
|
||||||
- **Usage Checker**: Real-time usage monitoring and quota management
|
|
||||||
- **Model Converter**: Unified model name conversion across providers
|
|
||||||
- **UTF-8 Stream Processing**: Improved streaming response handling
|
|
||||||
|
|
||||||
## Kiro Authentication
|
## Kiro Authentication
|
||||||
|
|
||||||
|
### CLI Login
|
||||||
|
|
||||||
|
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||||
|
|
||||||
|
**AWS Builder ID** (recommended):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Device code flow
|
||||||
|
./CLIProxyAPI --kiro-aws-login
|
||||||
|
|
||||||
|
# Authorization code flow
|
||||||
|
./CLIProxyAPI --kiro-aws-authcode
|
||||||
|
```
|
||||||
|
|
||||||
|
**Import token from Kiro IDE:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI --kiro-import
|
||||||
|
```
|
||||||
|
|
||||||
|
To get a token from Kiro IDE:
|
||||||
|
|
||||||
|
1. Open Kiro IDE and login with Google (or GitHub)
|
||||||
|
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
||||||
|
3. Run: `./CLIProxyAPI --kiro-import`
|
||||||
|
|
||||||
|
**AWS IAM Identity Center (IDC):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||||
|
|
||||||
|
# Specify region
|
||||||
|
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||||
|
```
|
||||||
|
|
||||||
|
**Additional flags:**
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `--no-browser` | Don't open browser automatically, print URL instead |
|
||||||
|
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
||||||
|
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
||||||
|
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
||||||
|
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
||||||
|
|
||||||
### Web-based OAuth Login
|
### Web-based OAuth Login
|
||||||
|
|
||||||
Access the Kiro OAuth web interface at:
|
Access the Kiro OAuth web interface at:
|
||||||
|
|||||||
60
README_CN.md
60
README_CN.md
@@ -10,22 +10,58 @@
|
|||||||
|
|
||||||
## 与主线版本版本差异
|
## 与主线版本版本差异
|
||||||
|
|
||||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
|
||||||
|
|
||||||
## 新增功能 (Plus 增强版)
|
## 新增功能 (Plus 增强版)
|
||||||
|
|
||||||
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
|
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
||||||
- **请求限流器**: 内置请求限流,防止 API 滥用
|
|
||||||
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
|
|
||||||
- **监控指标**: 请求指标收集,用于监控和调试
|
|
||||||
- **设备指纹**: 设备指纹生成,增强安全性
|
|
||||||
- **冷却管理**: 智能冷却机制,应对 API 速率限制
|
|
||||||
- **用量检查器**: 实时用量监控和配额管理
|
|
||||||
- **模型转换器**: 跨供应商的统一模型名称转换
|
|
||||||
- **UTF-8 流处理**: 改进的流式响应处理
|
|
||||||
|
|
||||||
## Kiro 认证
|
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||||
|
|
||||||
|
### 命令行登录
|
||||||
|
|
||||||
|
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
||||||
|
|
||||||
|
**AWS Builder ID**(推荐):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 设备码流程
|
||||||
|
./CLIProxyAPI --kiro-aws-login
|
||||||
|
|
||||||
|
# 授权码流程
|
||||||
|
./CLIProxyAPI --kiro-aws-authcode
|
||||||
|
```
|
||||||
|
|
||||||
|
**从 Kiro IDE 导入令牌:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI --kiro-import
|
||||||
|
```
|
||||||
|
|
||||||
|
获取令牌步骤:
|
||||||
|
|
||||||
|
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
||||||
|
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
||||||
|
3. 运行:`./CLIProxyAPI --kiro-import`
|
||||||
|
|
||||||
|
**AWS IAM Identity Center (IDC):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||||
|
|
||||||
|
# 指定区域
|
||||||
|
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||||
|
```
|
||||||
|
|
||||||
|
**附加参数:**
|
||||||
|
|
||||||
|
| 参数 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
||||||
|
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
||||||
|
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
||||||
|
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
||||||
|
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
||||||
|
|
||||||
### 网页端 OAuth 登录
|
### 网页端 OAuth 登录
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ func main() {
|
|||||||
// Command-line flags to control the application's behavior.
|
// Command-line flags to control the application's behavior.
|
||||||
var login bool
|
var login bool
|
||||||
var codexLogin bool
|
var codexLogin bool
|
||||||
|
var codexDeviceLogin bool
|
||||||
var claudeLogin bool
|
var claudeLogin bool
|
||||||
var qwenLogin bool
|
var qwenLogin bool
|
||||||
var kiloLogin bool
|
var kiloLogin bool
|
||||||
@@ -86,6 +87,10 @@ func main() {
|
|||||||
var kiroAWSLogin bool
|
var kiroAWSLogin bool
|
||||||
var kiroAWSAuthCode bool
|
var kiroAWSAuthCode bool
|
||||||
var kiroImport bool
|
var kiroImport bool
|
||||||
|
var kiroIDCLogin bool
|
||||||
|
var kiroIDCStartURL string
|
||||||
|
var kiroIDCRegion string
|
||||||
|
var kiroIDCFlow string
|
||||||
var githubCopilotLogin bool
|
var githubCopilotLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
@@ -99,6 +104,7 @@ func main() {
|
|||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
|
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||||
@@ -115,6 +121,10 @@ func main() {
|
|||||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
||||||
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||||
|
flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)")
|
||||||
|
flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)")
|
||||||
|
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||||
|
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
@@ -502,6 +512,9 @@ func main() {
|
|||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
|
} else if codexDeviceLogin {
|
||||||
|
// Handle Codex device-code login
|
||||||
|
cmd.DoCodexDeviceLogin(cfg, options)
|
||||||
} else if claudeLogin {
|
} else if claudeLogin {
|
||||||
// Handle Claude login
|
// Handle Claude login
|
||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
@@ -521,24 +534,34 @@ func main() {
|
|||||||
// Note: This config mutation is safe - auth commands exit after completion
|
// Note: This config mutation is safe - auth commands exit after completion
|
||||||
// and don't share config with StartService (which is in the else branch)
|
// and don't share config with StartService (which is in the else branch)
|
||||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
kiro.InitFingerprintConfig(cfg)
|
||||||
cmd.DoKiroLogin(cfg, options)
|
cmd.DoKiroLogin(cfg, options)
|
||||||
} else if kiroGoogleLogin {
|
} else if kiroGoogleLogin {
|
||||||
// For Kiro auth, default to incognito mode for multi-account support
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
// Users can explicitly override with --no-incognito
|
// Users can explicitly override with --no-incognito
|
||||||
// Note: This config mutation is safe - auth commands exit after completion
|
// Note: This config mutation is safe - auth commands exit after completion
|
||||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
kiro.InitFingerprintConfig(cfg)
|
||||||
cmd.DoKiroGoogleLogin(cfg, options)
|
cmd.DoKiroGoogleLogin(cfg, options)
|
||||||
} else if kiroAWSLogin {
|
} else if kiroAWSLogin {
|
||||||
// For Kiro auth, default to incognito mode for multi-account support
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
// Users can explicitly override with --no-incognito
|
// Users can explicitly override with --no-incognito
|
||||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
kiro.InitFingerprintConfig(cfg)
|
||||||
cmd.DoKiroAWSLogin(cfg, options)
|
cmd.DoKiroAWSLogin(cfg, options)
|
||||||
} else if kiroAWSAuthCode {
|
} else if kiroAWSAuthCode {
|
||||||
// For Kiro auth with authorization code flow (better UX)
|
// For Kiro auth with authorization code flow (better UX)
|
||||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
kiro.InitFingerprintConfig(cfg)
|
||||||
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
||||||
} else if kiroImport {
|
} else if kiroImport {
|
||||||
|
kiro.InitFingerprintConfig(cfg)
|
||||||
cmd.DoKiroImport(cfg, options)
|
cmd.DoKiroImport(cfg, options)
|
||||||
|
} else if kiroIDCLogin {
|
||||||
|
// For Kiro IDC auth, default to incognito mode for multi-account support
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
kiro.InitFingerprintConfig(cfg)
|
||||||
|
cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow)
|
||||||
} else {
|
} else {
|
||||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||||
if isCloudDeploy && !configFileExists {
|
if isCloudDeploy && !configFileExists {
|
||||||
|
|||||||
@@ -80,6 +80,10 @@ passthrough-headers: false
|
|||||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||||
request-retry: 3
|
request-retry: 3
|
||||||
|
|
||||||
|
# Maximum number of different credentials to try for one failed request.
|
||||||
|
# Set to 0 to keep legacy behavior (try all available credentials).
|
||||||
|
max-retry-credentials: 0
|
||||||
|
|
||||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||||
max-retry-interval: 30
|
max-retry-interval: 30
|
||||||
|
|
||||||
@@ -179,6 +183,8 @@ nonstream-keepalive-interval: 0
|
|||||||
#kiro:
|
#kiro:
|
||||||
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
||||||
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
||||||
|
# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login)
|
||||||
|
# region: "us-east-1" # optional: OIDC region for IDC login and token refresh
|
||||||
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
||||||
# refresh-token: "aorAAAAA..."
|
# refresh-token: "aorAAAAA..."
|
||||||
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||||
@@ -227,6 +233,9 @@ nonstream-keepalive-interval: 0
|
|||||||
# alias: "vertex-flash" # client-visible alias
|
# alias: "vertex-flash" # client-visible alias
|
||||||
# - name: "gemini-2.5-pro"
|
# - name: "gemini-2.5-pro"
|
||||||
# alias: "vertex-pro"
|
# alias: "vertex-pro"
|
||||||
|
# excluded-models: # optional: models to exclude from listing
|
||||||
|
# - "imagen-3.0-generate-002"
|
||||||
|
# - "imagen-*"
|
||||||
|
|
||||||
# Amp Integration
|
# Amp Integration
|
||||||
# ampcode:
|
# ampcode:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -48,14 +49,11 @@ import (
|
|||||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
anthropicCallbackPort = 54545
|
anthropicCallbackPort = 54545
|
||||||
geminiCallbackPort = 8085
|
geminiCallbackPort = 8085
|
||||||
codexCallbackPort = 1455
|
codexCallbackPort = 1455
|
||||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
geminiCLIVersion = "v1internal"
|
geminiCLIVersion = "v1internal"
|
||||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
|
||||||
geminiCLIApiClient = "gl-node/22.17.0"
|
|
||||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type callbackForwarder struct {
|
type callbackForwarder struct {
|
||||||
@@ -195,17 +193,6 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor
|
|||||||
return forwarder, nil
|
return forwarder, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func stopCallbackForwarder(port int) {
|
|
||||||
callbackForwardersMu.Lock()
|
|
||||||
forwarder := callbackForwarders[port]
|
|
||||||
if forwarder != nil {
|
|
||||||
delete(callbackForwarders, port)
|
|
||||||
}
|
|
||||||
callbackForwardersMu.Unlock()
|
|
||||||
|
|
||||||
stopForwarderInstance(port, forwarder)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||||
if forwarder == nil {
|
if forwarder == nil {
|
||||||
return
|
return
|
||||||
@@ -412,6 +399,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
|||||||
if !auth.LastRefreshedAt.IsZero() {
|
if !auth.LastRefreshedAt.IsZero() {
|
||||||
entry["last_refresh"] = auth.LastRefreshedAt
|
entry["last_refresh"] = auth.LastRefreshedAt
|
||||||
}
|
}
|
||||||
|
if !auth.NextRetryAfter.IsZero() {
|
||||||
|
entry["next_retry_after"] = auth.NextRetryAfter
|
||||||
|
}
|
||||||
if path != "" {
|
if path != "" {
|
||||||
entry["path"] = path
|
entry["path"] = path
|
||||||
entry["source"] = "file"
|
entry["source"] = "file"
|
||||||
@@ -644,44 +634,85 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
|||||||
c.JSON(400, gin.H{"error": "invalid name"})
|
c.JSON(400, gin.H{"error": "invalid name"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
full := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
|
||||||
if !filepath.IsAbs(full) {
|
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||||
if abs, errAbs := filepath.Abs(full); errAbs == nil {
|
targetID := ""
|
||||||
full = abs
|
if targetAuth := h.findAuthForDelete(name); targetAuth != nil {
|
||||||
|
targetID = strings.TrimSpace(targetAuth.ID)
|
||||||
|
if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" {
|
||||||
|
targetPath = path
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := os.Remove(full); err != nil {
|
if !filepath.IsAbs(targetPath) {
|
||||||
if os.IsNotExist(err) {
|
if abs, errAbs := filepath.Abs(targetPath); errAbs == nil {
|
||||||
|
targetPath = abs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errRemove := os.Remove(targetPath); errRemove != nil {
|
||||||
|
if os.IsNotExist(errRemove) {
|
||||||
c.JSON(404, gin.H{"error": "file not found"})
|
c.JSON(404, gin.H{"error": "file not found"})
|
||||||
} else {
|
} else {
|
||||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)})
|
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.deleteTokenRecord(ctx, full); err != nil {
|
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
|
||||||
c.JSON(500, gin.H{"error": err.Error()})
|
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.disableAuth(ctx, full)
|
if targetID != "" {
|
||||||
|
h.disableAuth(ctx, targetID)
|
||||||
|
} else {
|
||||||
|
h.disableAuth(ctx, targetPath)
|
||||||
|
}
|
||||||
c.JSON(200, gin.H{"status": "ok"})
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
|
||||||
|
if h == nil || h.authManager == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if auth, ok := h.authManager.GetByID(name); ok {
|
||||||
|
return auth
|
||||||
|
}
|
||||||
|
auths := h.authManager.List()
|
||||||
|
for _, auth := range auths {
|
||||||
|
if auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(auth.FileName) == name {
|
||||||
|
return auth
|
||||||
|
}
|
||||||
|
if filepath.Base(strings.TrimSpace(authAttribute(auth, "path"))) == name {
|
||||||
|
return auth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) authIDForPath(path string) string {
|
func (h *Handler) authIDForPath(path string) string {
|
||||||
path = strings.TrimSpace(path)
|
path = strings.TrimSpace(path)
|
||||||
if path == "" {
|
if path == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if h == nil || h.cfg == nil {
|
id := path
|
||||||
return path
|
if h != nil && h.cfg != nil {
|
||||||
|
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||||
|
if authDir != "" {
|
||||||
|
if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
|
||||||
|
id = rel
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
|
||||||
if authDir == "" {
|
if runtime.GOOS == "windows" {
|
||||||
return path
|
id = strings.ToLower(id)
|
||||||
}
|
}
|
||||||
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
|
return id
|
||||||
return rel
|
|
||||||
}
|
|
||||||
return path
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
||||||
@@ -899,10 +930,19 @@ func (h *Handler) disableAuth(ctx context.Context, id string) {
|
|||||||
if h == nil || h.authManager == nil {
|
if h == nil || h.authManager == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authID := h.authIDForPath(id)
|
id = strings.TrimSpace(id)
|
||||||
if authID == "" {
|
if id == "" {
|
||||||
authID = strings.TrimSpace(id)
|
return
|
||||||
}
|
}
|
||||||
|
if auth, ok := h.authManager.GetByID(id); ok {
|
||||||
|
auth.Disabled = true
|
||||||
|
auth.Status = coreauth.StatusDisabled
|
||||||
|
auth.StatusMessage = "removed via management API"
|
||||||
|
auth.UpdatedAt = time.Now()
|
||||||
|
_, _ = h.authManager.Update(ctx, auth)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authID := h.authIDForPath(id)
|
||||||
if authID == "" {
|
if authID == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -951,11 +991,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
|||||||
if store == nil {
|
if store == nil {
|
||||||
return "", fmt.Errorf("token store unavailable")
|
return "", fmt.Errorf("token store unavailable")
|
||||||
}
|
}
|
||||||
|
if h.postAuthHook != nil {
|
||||||
|
if err := h.postAuthHook(ctx, record); err != nil {
|
||||||
|
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return store.Save(ctx, record)
|
return store.Save(ctx, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Claude authentication...")
|
fmt.Println("Initializing Claude authentication...")
|
||||||
|
|
||||||
@@ -1100,6 +1146,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||||
|
|
||||||
@@ -1358,6 +1405,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Codex authentication...")
|
fmt.Println("Initializing Codex authentication...")
|
||||||
|
|
||||||
@@ -1503,6 +1551,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
@@ -1667,6 +1716,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Qwen authentication...")
|
fmt.Println("Initializing Qwen authentication...")
|
||||||
|
|
||||||
@@ -1722,6 +1772,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Kimi authentication...")
|
fmt.Println("Initializing Kimi authentication...")
|
||||||
|
|
||||||
@@ -1798,6 +1849,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing iFlow authentication...")
|
fmt.Println("Initializing iFlow authentication...")
|
||||||
|
|
||||||
@@ -1917,8 +1969,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
// Initialize Copilot auth service
|
// Initialize Copilot auth service
|
||||||
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
|
|
||||||
// Assuming copilot package is imported as "copilot"
|
|
||||||
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
||||||
|
|
||||||
// Initiate device flow
|
// Initiate device flow
|
||||||
@@ -1932,7 +1982,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
authURL := deviceCode.VerificationURI
|
authURL := deviceCode.VerificationURI
|
||||||
userCode := deviceCode.UserCode
|
userCode := deviceCode.UserCode
|
||||||
|
|
||||||
RegisterOAuthSession(state, "github")
|
RegisterOAuthSession(state, "github-copilot")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
||||||
@@ -1944,9 +1994,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
if errUser != nil {
|
if errUser != nil {
|
||||||
log.Warnf("Failed to fetch user info: %v", errUser)
|
log.Warnf("Failed to fetch user info: %v", errUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := userInfo.Login
|
||||||
|
if username == "" {
|
||||||
username = "github-user"
|
username = "github-user"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1955,18 +2009,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
TokenType: tokenData.TokenType,
|
TokenType: tokenData.TokenType,
|
||||||
Scope: tokenData.Scope,
|
Scope: tokenData.Scope,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
Name: userInfo.Name,
|
||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := fmt.Sprintf("github-%s.json", username)
|
fileName := fmt.Sprintf("github-copilot-%s.json", username)
|
||||||
|
label := userInfo.Email
|
||||||
|
if label == "" {
|
||||||
|
label = username
|
||||||
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: "github",
|
Provider: "github-copilot",
|
||||||
|
Label: label,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": username,
|
"email": userInfo.Email,
|
||||||
"username": username,
|
"username": username,
|
||||||
|
"name": userInfo.Name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1980,7 +2042,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
||||||
CompleteOAuthSession(state)
|
CompleteOAuthSession(state)
|
||||||
CompleteOAuthSessionsByProvider("github")
|
CompleteOAuthSessionsByProvider("github-copilot")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -2359,9 +2421,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
|||||||
return fmt.Errorf("create request: %w", errRequest)
|
return fmt.Errorf("create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
|
||||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
@@ -2431,7 +2491,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||||
@@ -2452,7 +2512,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
resp, errDo = httpClient.Do(req)
|
resp, errDo = httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||||
@@ -2521,6 +2581,15 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateAuthContext extracts request info and adds it to the context
|
||||||
|
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||||
|
info := &coreauth.RequestInfo{
|
||||||
|
Query: c.Request.URL.Query(),
|
||||||
|
Headers: c.Request.Header,
|
||||||
|
}
|
||||||
|
return coreauth.WithRequestInfo(ctx, info)
|
||||||
|
}
|
||||||
|
|
||||||
const kiroCallbackPort = 9876
|
const kiroCallbackPort = 9876
|
||||||
|
|
||||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||||
@@ -2657,6 +2726,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
isWebUI := isWebUIRequest(c)
|
isWebUI := isWebUIRequest(c)
|
||||||
|
var forwarder *callbackForwarder
|
||||||
if isWebUI {
|
if isWebUI {
|
||||||
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||||
if errTarget != nil {
|
if errTarget != nil {
|
||||||
@@ -2664,7 +2734,8 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
var errStart error
|
||||||
|
if forwarder, errStart = startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||||
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
return
|
return
|
||||||
@@ -2673,7 +2744,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if isWebUI {
|
if isWebUI {
|
||||||
defer stopCallbackForwarder(kiroCallbackPort)
|
defer stopCallbackForwarderInstance(kiroCallbackPort, forwarder)
|
||||||
}
|
}
|
||||||
|
|
||||||
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||||
@@ -2876,7 +2947,7 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
|
|||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": status.UserEmail,
|
"email": status.UserEmail,
|
||||||
"organization_id": orgID,
|
"organization_id": orgID,
|
||||||
"model": defaults.Model,
|
"model": defaults.Model,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tempDir, "auth")
|
||||||
|
externalDir := filepath.Join(tempDir, "external")
|
||||||
|
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||||
|
}
|
||||||
|
if errMkdirExternal := os.MkdirAll(externalDir, 0o700); errMkdirExternal != nil {
|
||||||
|
t.Fatalf("failed to create external dir: %v", errMkdirExternal)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := "codex-user@example.com-plus.json"
|
||||||
|
shadowPath := filepath.Join(authDir, fileName)
|
||||||
|
realPath := filepath.Join(externalDir, fileName)
|
||||||
|
if errWriteShadow := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); errWriteShadow != nil {
|
||||||
|
t.Fatalf("failed to write shadow file: %v", errWriteShadow)
|
||||||
|
}
|
||||||
|
if errWriteReal := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); errWriteReal != nil {
|
||||||
|
t.Fatalf("failed to write real file: %v", errWriteReal)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "legacy/" + fileName,
|
||||||
|
FileName: fileName,
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusError,
|
||||||
|
Unavailable: true,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": realPath,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "codex",
|
||||||
|
"email": "real@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
deleteRec := httptest.NewRecorder()
|
||||||
|
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||||
|
deleteCtx.Request = deleteReq
|
||||||
|
h.DeleteAuthFile(deleteCtx)
|
||||||
|
|
||||||
|
if deleteRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||||
|
}
|
||||||
|
if _, errStatReal := os.Stat(realPath); !os.IsNotExist(errStatReal) {
|
||||||
|
t.Fatalf("expected managed auth file to be removed, stat err: %v", errStatReal)
|
||||||
|
}
|
||||||
|
if _, errStatShadow := os.Stat(shadowPath); errStatShadow != nil {
|
||||||
|
t.Fatalf("expected shadow auth file to remain, stat err: %v", errStatShadow)
|
||||||
|
}
|
||||||
|
|
||||||
|
listRec := httptest.NewRecorder()
|
||||||
|
listCtx, _ := gin.CreateTestContext(listRec)
|
||||||
|
listReq := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||||
|
listCtx.Request = listReq
|
||||||
|
h.ListAuthFiles(listCtx)
|
||||||
|
|
||||||
|
if listRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String())
|
||||||
|
}
|
||||||
|
var listPayload map[string]any
|
||||||
|
if errUnmarshal := json.Unmarshal(listRec.Body.Bytes(), &listPayload); errUnmarshal != nil {
|
||||||
|
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||||
|
}
|
||||||
|
filesRaw, ok := listPayload["files"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected files array, payload: %#v", listPayload)
|
||||||
|
}
|
||||||
|
if len(filesRaw) != 0 {
|
||||||
|
t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(filesRaw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
fileName := "fallback-user.json"
|
||||||
|
filePath := filepath.Join(authDir, fileName)
|
||||||
|
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex"}`), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
deleteRec := httptest.NewRecorder()
|
||||||
|
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||||
|
deleteCtx.Request = deleteReq
|
||||||
|
h.DeleteAuthFile(deleteCtx)
|
||||||
|
|
||||||
|
if deleteRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||||
|
}
|
||||||
|
if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) {
|
||||||
|
t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", errStat)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -516,12 +516,13 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||||
type vertexCompatPatch struct {
|
type vertexCompatPatch struct {
|
||||||
APIKey *string `json:"api-key"`
|
APIKey *string `json:"api-key"`
|
||||||
Prefix *string `json:"prefix"`
|
Prefix *string `json:"prefix"`
|
||||||
BaseURL *string `json:"base-url"`
|
BaseURL *string `json:"base-url"`
|
||||||
ProxyURL *string `json:"proxy-url"`
|
ProxyURL *string `json:"proxy-url"`
|
||||||
Headers *map[string]string `json:"headers"`
|
Headers *map[string]string `json:"headers"`
|
||||||
Models *[]config.VertexCompatModel `json:"models"`
|
Models *[]config.VertexCompatModel `json:"models"`
|
||||||
|
ExcludedModels *[]string `json:"excluded-models"`
|
||||||
}
|
}
|
||||||
var body struct {
|
var body struct {
|
||||||
Index *int `json:"index"`
|
Index *int `json:"index"`
|
||||||
@@ -585,6 +586,9 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
|||||||
if body.Value.Models != nil {
|
if body.Value.Models != nil {
|
||||||
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
||||||
}
|
}
|
||||||
|
if body.Value.ExcludedModels != nil {
|
||||||
|
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||||
|
}
|
||||||
normalizeVertexCompatKey(&entry)
|
normalizeVertexCompatKey(&entry)
|
||||||
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
@@ -1029,6 +1033,7 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
|
|||||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||||
|
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
if len(entry.Models) == 0 {
|
if len(entry.Models) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type Handler struct {
|
|||||||
allowRemoteOverride bool
|
allowRemoteOverride bool
|
||||||
envSecret string
|
envSecret string
|
||||||
logDir string
|
logDir string
|
||||||
|
postAuthHook coreauth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new management handler instance.
|
// NewHandler creates a new management handler instance.
|
||||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
|||||||
h.logDir = dir
|
h.logDir = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||||
|
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||||
|
h.postAuthHook = hook
|
||||||
|
}
|
||||||
|
|
||||||
// Middleware enforces access control for management endpoints.
|
// Middleware enforces access control for management endpoints.
|
||||||
// All requests (local and remote) require a valid management key.
|
// All requests (local and remote) require a valid management key.
|
||||||
// Additionally, remote access requires allow-remote-management=true.
|
// Additionally, remote access requires allow-remote-management=true.
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -77,6 +78,9 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
req.Header.Del("X-Api-Key")
|
req.Header.Del("X-Api-Key")
|
||||||
req.Header.Del("X-Goog-Api-Key")
|
req.Header.Del("X-Goog-Api-Key")
|
||||||
|
|
||||||
|
// Remove proxy, client identity, and browser fingerprint headers
|
||||||
|
misc.ScrubProxyAndFingerprintHeaders(req)
|
||||||
|
|
||||||
// Remove query-based credentials if they match the authenticated client API key.
|
// Remove query-based credentials if they match the authenticated client API key.
|
||||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||||
// breaking unrelated upstream query parameters.
|
// breaking unrelated upstream query parameters.
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ type serverOptionConfig struct {
|
|||||||
keepAliveEnabled bool
|
keepAliveEnabled bool
|
||||||
keepAliveTimeout time.Duration
|
keepAliveTimeout time.Duration
|
||||||
keepAliveOnTimeout func()
|
keepAliveOnTimeout func()
|
||||||
|
postAuthHook auth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerOption customises HTTP server construction.
|
// ServerOption customises HTTP server construction.
|
||||||
@@ -59,10 +60,8 @@ type ServerOption func(*serverOptionConfig)
|
|||||||
|
|
||||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||||
configDir := filepath.Dir(configPath)
|
configDir := filepath.Dir(configPath)
|
||||||
if base := util.WritablePath(); base != "" {
|
logsDir := logging.ResolveLogDirectory(cfg)
|
||||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
|
return logging.NewFileRequestLogger(cfg.RequestLog, logsDir, configDir, cfg.ErrorLogsMaxFiles)
|
||||||
}
|
|
||||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMiddleware appends additional Gin middleware during server construction.
|
// WithMiddleware appends additional Gin middleware during server construction.
|
||||||
@@ -112,6 +111,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||||
|
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||||
|
return func(cfg *serverOptionConfig) {
|
||||||
|
cfg.postAuthHook = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Server represents the main API server.
|
// Server represents the main API server.
|
||||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -252,7 +258,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||||
s.applyAccessConfig(nil, cfg)
|
s.applyAccessConfig(nil, cfg)
|
||||||
if authManager != nil {
|
if authManager != nil {
|
||||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||||
}
|
}
|
||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
@@ -263,6 +269,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
logDir := logging.ResolveLogDirectory(cfg)
|
logDir := logging.ResolveLogDirectory(cfg)
|
||||||
s.mgmt.SetLogDirectory(logDir)
|
s.mgmt.SetLogDirectory(logDir)
|
||||||
|
if optionState.postAuthHook != nil {
|
||||||
|
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||||
|
}
|
||||||
s.localPassword = optionState.localPassword
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
@@ -935,7 +944,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update log level dynamically when debug flag changes
|
// Update log level dynamically when debug flag changes
|
||||||
|
|||||||
@@ -7,9 +7,11 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
gin "github.com/gin-gonic/gin"
|
gin "github.com/gin-gonic/gin"
|
||||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
@@ -109,3 +111,100 @@ func TestAmpProviderModelRoutes(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||||
|
t.Setenv("WRITABLE_PATH", "")
|
||||||
|
t.Setenv("writable_path", "")
|
||||||
|
|
||||||
|
originalWD, errGetwd := os.Getwd()
|
||||||
|
if errGetwd != nil {
|
||||||
|
t.Fatalf("failed to get current working directory: %v", errGetwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
if errChdir := os.Chdir(tmpDir); errChdir != nil {
|
||||||
|
t.Fatalf("failed to switch working directory: %v", errChdir)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errChdirBack := os.Chdir(originalWD); errChdirBack != nil {
|
||||||
|
t.Fatalf("failed to restore working directory: %v", errChdirBack)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Force ResolveLogDirectory to fallback to auth-dir/logs by making ./logs not a writable directory.
|
||||||
|
if errWriteFile := os.WriteFile(filepath.Join(tmpDir, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
|
||||||
|
t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, "config")
|
||||||
|
if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
|
||||||
|
t.Fatalf("failed to create config dir: %v", errMkdirConfig)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(configDir, "config.yaml")
|
||||||
|
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &proxyconfig.Config{
|
||||||
|
SDKConfig: proxyconfig.SDKConfig{
|
||||||
|
RequestLog: false,
|
||||||
|
},
|
||||||
|
AuthDir: authDir,
|
||||||
|
ErrorLogsMaxFiles: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := defaultRequestLoggerFactory(cfg, configPath)
|
||||||
|
fileLogger, ok := logger.(*internallogging.FileRequestLogger)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected *FileRequestLogger, got %T", logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
errLog := fileLogger.LogRequestWithOptions(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
http.MethodPost,
|
||||||
|
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||||
|
[]byte(`{"input":"hello"}`),
|
||||||
|
http.StatusBadGateway,
|
||||||
|
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||||
|
[]byte(`{"error":"upstream failure"}`),
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
"issue-1711",
|
||||||
|
time.Now(),
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
if errLog != nil {
|
||||||
|
t.Fatalf("failed to write forced error request log: %v", errLog)
|
||||||
|
}
|
||||||
|
|
||||||
|
authLogsDir := filepath.Join(authDir, "logs")
|
||||||
|
authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
|
||||||
|
if errReadAuthDir != nil {
|
||||||
|
t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
|
||||||
|
}
|
||||||
|
foundErrorLogInAuthDir := false
|
||||||
|
for _, entry := range authEntries {
|
||||||
|
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||||
|
foundErrorLogInAuthDir = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundErrorLogInAuthDir {
|
||||||
|
t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
|
||||||
|
}
|
||||||
|
|
||||||
|
configLogsDir := filepath.Join(configDir, "logs")
|
||||||
|
configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
|
||||||
|
if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
|
||||||
|
t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
|
||||||
|
}
|
||||||
|
for _, entry := range configEntries {
|
||||||
|
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||||
|
t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
|||||||
|
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode and write the token data as JSON
|
// Encode and write the token data as JSON
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
|
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||||
type utlsRoundTripper struct {
|
type utlsRoundTripper struct {
|
||||||
// mu protects the connections map and pending map
|
// mu protects the connections map and pending map
|
||||||
@@ -100,7 +100,9 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
|
|||||||
return h2Conn, nil
|
return h2Conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
|
// createConnection creates a new HTTP/2 connection with Chrome TLS fingerprint.
|
||||||
|
// Chrome's TLS fingerprint is closer to Node.js/OpenSSL (which real Claude Code uses)
|
||||||
|
// than Firefox, reducing the mismatch between TLS layer and HTTP headers.
|
||||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
conn, err := t.dialer.Dial("tcp", addr)
|
conn, err := t.dialer.Dial("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -108,7 +110,7 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
|
|||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{ServerName: host}
|
tlsConfig := &tls.Config{ServerName: host}
|
||||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
|
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||||
|
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@@ -156,7 +158,7 @@ func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
||||||
// for Anthropic domains by using utls with Firefox fingerprint.
|
// for Anthropic domains by using utls with Chrome fingerprint.
|
||||||
// It accepts optional SDK configuration for proxy settings.
|
// It accepts optional SDK configuration for proxy settings.
|
||||||
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
|
|||||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||||
// authorization code and PKCE verifier.
|
// authorization code and PKCE verifier.
|
||||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
|
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||||
|
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||||
|
// login while preserving the existing token parsing and storage behavior.
|
||||||
|
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
if pkceCodes == nil {
|
if pkceCodes == nil {
|
||||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return tokenData, nil
|
return tokenData, nil
|
||||||
}
|
}
|
||||||
|
if isNonRetryableRefreshErr(err) {
|
||||||
|
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
lastErr = err
|
lastErr = err
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isNonRetryableRefreshErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(raw, "refresh_token_reused")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||||
// This is typically called after a successful token refresh to persist the new credentials.
|
// This is typically called after a successful token refresh to persist the new credentials.
|
||||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||||
|
|||||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
auth := &CodexAuth{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for non-retryable refresh failure")
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||||
|
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -82,15 +84,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the GitHub username
|
// Fetch the GitHub username
|
||||||
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||||
username = "unknown"
|
}
|
||||||
|
|
||||||
|
username := userInfo.Login
|
||||||
|
if username == "" {
|
||||||
|
username = "github-user"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CopilotAuthBundle{
|
return &CopilotAuthBundle{
|
||||||
TokenData: tokenData,
|
TokenData: tokenData,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
Name: userInfo.Name,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,12 +158,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
|
|||||||
return false, "", nil
|
return false, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", err
|
return false, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, username, nil
|
return true, userInfo.Login, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||||
@@ -165,6 +173,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
|
|||||||
TokenType: bundle.TokenData.TokenType,
|
TokenType: bundle.TokenData.TokenType,
|
||||||
Scope: bundle.TokenData.Scope,
|
Scope: bundle.TokenData.Scope,
|
||||||
Username: bundle.Username,
|
Username: bundle.Username,
|
||||||
|
Email: bundle.Email,
|
||||||
|
Name: bundle.Name,
|
||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -214,6 +224,97 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
|
||||||
|
type CopilotModelEntry struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||||
|
type CopilotModelsResponse struct {
|
||||||
|
Data []CopilotModelEntry `json:"data"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
|
||||||
|
const maxModelsResponseSize = 2 * 1024 * 1024
|
||||||
|
|
||||||
|
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||||
|
var allowedCopilotAPIHosts = map[string]bool{
|
||||||
|
"api.githubcopilot.com": true,
|
||||||
|
"api.individual.githubcopilot.com": true,
|
||||||
|
"api.business.githubcopilot.com": true,
|
||||||
|
"copilot-proxy.githubusercontent.com": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels fetches the list of available models from the Copilot API.
|
||||||
|
// It requires a valid Copilot API token (not the GitHub access token).
|
||||||
|
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
|
||||||
|
if apiToken == nil || apiToken.Token == "" {
|
||||||
|
return nil, fmt.Errorf("copilot: api token is required for listing models")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build models URL, validating the endpoint host to prevent SSRF.
|
||||||
|
modelsURL := copilotAPIEndpoint + "/models"
|
||||||
|
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
|
||||||
|
parsed, err := url.Parse(ep)
|
||||||
|
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
|
||||||
|
modelsURL = ep + "/models"
|
||||||
|
} else {
|
||||||
|
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: models request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot list models: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Limit response body to prevent memory exhaustion.
|
||||||
|
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
|
||||||
|
bodyBytes, err := io.ReadAll(limitedReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
|
return nil, fmt.Errorf("copilot: list models failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsResp CopilotModelsResponse
|
||||||
|
if err = json.Unmarshal(bodyBytes, &modelsResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to parse models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelsResp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModelsWithGitHubToken is a convenience method that exchanges a GitHub access token
|
||||||
|
// for a Copilot API token and then fetches the available models.
|
||||||
|
func (c *CopilotAuth) ListModelsWithGitHubToken(ctx context.Context, githubAccessToken string) ([]CopilotModelEntry, error) {
|
||||||
|
apiToken, err := c.GetCopilotAPIToken(ctx, githubAccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to get API token for model listing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.ListModels(ctx, apiToken)
|
||||||
|
}
|
||||||
|
|
||||||
// buildChatCompletionURL builds the URL for chat completions API.
|
// buildChatCompletionURL builds the URL for chat completions API.
|
||||||
func buildChatCompletionURL() string {
|
func buildChatCompletionURL() string {
|
||||||
return copilotAPIEndpoint + "/chat/completions"
|
return copilotAPIEndpoint + "/chat/completions"
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
|||||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("client_id", copilotClientID)
|
data.Set("client_id", copilotClientID)
|
||||||
data.Set("scope", "user:email")
|
data.Set("scope", "read:user user:email")
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
// GitHubUserInfo holds GitHub user profile information.
|
||||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
type GitHubUserInfo struct {
|
||||||
|
// Login is the GitHub username.
|
||||||
|
Login string
|
||||||
|
// Email is the primary email address (may be empty if not public).
|
||||||
|
Email string
|
||||||
|
// Name is the display name.
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||||
|
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
|||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
|||||||
|
|
||||||
if !isHTTPSuccess(resp.StatusCode) {
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var userInfo struct {
|
var raw struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userInfo.Login == "" {
|
if raw.Login == "" {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo.Login, nil
|
return GitHubUserInfo{
|
||||||
|
Login: raw.Login,
|
||||||
|
Email: raw.Email,
|
||||||
|
Name: raw.Name,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// roundTripFunc lets us inject a custom transport for testing.
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||||
|
|
||||||
|
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||||
|
// regardless of the original URL host.
|
||||||
|
func newTestClient(srv *httptest.Server) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req2 := req.Clone(req.Context())
|
||||||
|
req2.URL.Scheme = "http"
|
||||||
|
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||||
|
return srv.Client().Transport.RoundTrip(req2)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||||
|
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"login": "octocat",
|
||||||
|
"email": "octocat@github.com",
|
||||||
|
"name": "The Octocat",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if info.Login != "octocat" {
|
||||||
|
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||||
|
}
|
||||||
|
if info.Email != "octocat@github.com" {
|
||||||
|
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||||
|
}
|
||||||
|
if info.Name != "The Octocat" {
|
||||||
|
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||||
|
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// GitHub returns null for private emails.
|
||||||
|
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if info.Login != "privateuser" {
|
||||||
|
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||||
|
}
|
||||||
|
if info.Email != "" {
|
||||||
|
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||||
|
}
|
||||||
|
if info.Name != "Private User" {
|
||||||
|
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||||
|
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||||
|
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty token, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||||
|
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty login, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||||
|
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401 response, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||||
|
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||||
|
ts := &CopilotTokenStorage{
|
||||||
|
AccessToken: "ghu_abc",
|
||||||
|
TokenType: "bearer",
|
||||||
|
Scope: "read:user user:email",
|
||||||
|
Username: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err = json.Unmarshal(data, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||||
|
if _, ok := out[key]; !ok {
|
||||||
|
t.Errorf("expected key %q in JSON output, not found", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out["email"] != "octocat@github.com" {
|
||||||
|
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||||
|
}
|
||||||
|
if out["name"] != "The Octocat" {
|
||||||
|
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||||
|
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||||
|
ts := &CopilotTokenStorage{
|
||||||
|
AccessToken: "ghu_abc",
|
||||||
|
Username: "octocat",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err = json.Unmarshal(data, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := out["email"]; ok {
|
||||||
|
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||||
|
}
|
||||||
|
if _, ok := out["name"]; ok {
|
||||||
|
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||||
|
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||||
|
bundle := &CopilotAuthBundle{
|
||||||
|
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||||
|
Username: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
}
|
||||||
|
if bundle.Email != "octocat@github.com" {
|
||||||
|
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||||
|
}
|
||||||
|
if bundle.Name != "The Octocat" {
|
||||||
|
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||||
|
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||||
|
info := GitHubUserInfo{
|
||||||
|
Login: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
}
|
||||||
|
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||||
|
t.Error("GitHubUserInfo fields should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
|
|||||||
ExpiresAt string `json:"expires_at,omitempty"`
|
ExpiresAt string `json:"expires_at,omitempty"`
|
||||||
// Username is the GitHub username associated with this token.
|
// Username is the GitHub username associated with this token.
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
// Email is the GitHub email address associated with this token.
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
// Name is the GitHub display name associated with this token.
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
|
|||||||
TokenData *CopilotTokenData
|
TokenData *CopilotTokenData
|
||||||
// Username is the GitHub username.
|
// Username is the GitHub username.
|
||||||
Username string
|
Username string
|
||||||
|
// Email is the GitHub email address.
|
||||||
|
Email string
|
||||||
|
// Name is the GitHub display name.
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceCodeResponse represents GitHub's device code response.
|
// DeviceCodeResponse represents GitHub's device code response.
|
||||||
|
|||||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
|||||||
|
|
||||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
|||||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
misc.LogSavingCredentials(authFilePath)
|
misc.LogSavingCredentials(authFilePath)
|
||||||
ts.Type = "gemini"
|
ts.Type = "gemini"
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
|||||||
Scope string `json:"scope"`
|
Scope string `json:"scope"`
|
||||||
Cookie string `json:"cookie"`
|
Cookie string `json:"cookie"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serialises the token storage to disk.
|
// SaveTokenToFile serialises the token storage to disk.
|
||||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
defer func() { _ = f.Close() }()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
|||||||
Expired string `json:"expired,omitempty"`
|
Expired string `json:"expired,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
encoder := json.NewEncoder(f)
|
encoder := json.NewEncoder(f)
|
||||||
encoder.SetIndent("", " ")
|
encoder.SetIndent("", " ")
|
||||||
if err = encoder.Encode(ts); err != nil {
|
if err = encoder.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,10 +7,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||||
@@ -47,7 +50,7 @@ type KiroTokenData struct {
|
|||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||||
StartURL string `json:"startUrl,omitempty"`
|
StartURL string `json:"startUrl,omitempty"`
|
||||||
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
// Region is the OIDC region for IDC login and token refresh
|
||||||
Region string `json:"region,omitempty"`
|
Region string `json:"region,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,3 +523,159 @@ func GenerateTokenFileName(tokenData *KiroTokenData) string {
|
|||||||
// Priority 3: Fallback to authMethod only with sequence
|
// Priority 3: Fallback to authMethod only with sequence
|
||||||
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
|
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultKiroRegion is the fallback region when none is specified.
|
||||||
|
const DefaultKiroRegion = "us-east-1"
|
||||||
|
|
||||||
|
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
|
||||||
|
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
|
||||||
|
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
|
||||||
|
func GetCodeWhispererLegacyEndpoint(region string) string {
|
||||||
|
if region == "" {
|
||||||
|
region = DefaultKiroRegion
|
||||||
|
}
|
||||||
|
return "https://codewhisperer." + region + ".amazonaws.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
|
||||||
|
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
|
||||||
|
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
|
||||||
|
type ProfileARN struct {
|
||||||
|
// Raw is the original ARN string
|
||||||
|
Raw string
|
||||||
|
// Partition is the AWS partition (aws)
|
||||||
|
Partition string
|
||||||
|
// Service is the AWS service name (codewhisperer)
|
||||||
|
Service string
|
||||||
|
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
|
||||||
|
Region string
|
||||||
|
// AccountID is the AWS account ID
|
||||||
|
AccountID string
|
||||||
|
// ResourceType is the resource type (profile)
|
||||||
|
ResourceType string
|
||||||
|
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
|
||||||
|
ResourceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
|
||||||
|
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
|
||||||
|
func ParseProfileARN(arn string) *ProfileARN {
|
||||||
|
if arn == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// ARN format: arn:partition:service:region:account-id:resource
|
||||||
|
// Minimum 6 parts separated by ":"
|
||||||
|
parts := strings.Split(arn, ":")
|
||||||
|
if len(parts) < 6 {
|
||||||
|
log.Warnf("invalid ARN format: %s", arn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Validate ARN prefix
|
||||||
|
if parts[0] != "arn" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Validate partition
|
||||||
|
partition := parts[1]
|
||||||
|
if partition == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Validate service is codewhisperer
|
||||||
|
service := parts[2]
|
||||||
|
if service != "codewhisperer" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Validate region format (must contain "-")
|
||||||
|
region := parts[3]
|
||||||
|
if region == "" || !strings.Contains(region, "-") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Account ID
|
||||||
|
accountID := parts[4]
|
||||||
|
|
||||||
|
// Parse resource (format: resource-type/resource-id)
|
||||||
|
// Join remaining parts in case resource contains ":"
|
||||||
|
resource := strings.Join(parts[5:], ":")
|
||||||
|
resourceType := ""
|
||||||
|
resourceID := ""
|
||||||
|
if idx := strings.Index(resource, "/"); idx > 0 {
|
||||||
|
resourceType = resource[:idx]
|
||||||
|
resourceID = resource[idx+1:]
|
||||||
|
} else {
|
||||||
|
resourceType = resource
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ProfileARN{
|
||||||
|
Raw: arn,
|
||||||
|
Partition: partition,
|
||||||
|
Service: service,
|
||||||
|
Region: region,
|
||||||
|
AccountID: accountID,
|
||||||
|
ResourceType: resourceType,
|
||||||
|
ResourceID: resourceID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
|
||||||
|
// If region is empty, defaults to us-east-1.
|
||||||
|
func GetKiroAPIEndpoint(region string) string {
|
||||||
|
if region == "" {
|
||||||
|
region = DefaultKiroRegion
|
||||||
|
}
|
||||||
|
return "https://q." + region + ".amazonaws.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
|
||||||
|
// Returns default us-east-1 endpoint if region cannot be extracted.
|
||||||
|
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
|
||||||
|
region := ExtractRegionFromProfileArn(profileArn)
|
||||||
|
return GetKiroAPIEndpoint(region)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
|
||||||
|
// Returns empty string if ARN is invalid or region cannot be extracted.
|
||||||
|
func ExtractRegionFromProfileArn(profileArn string) string {
|
||||||
|
parsed := ParseProfileARN(profileArn)
|
||||||
|
if parsed == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return parsed.Region
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractRegionFromMetadata extracts API region from auth metadata.
|
||||||
|
// Priority: api_region > profile_arn > DefaultKiroRegion
|
||||||
|
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
|
||||||
|
if metadata == nil {
|
||||||
|
return DefaultKiroRegion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 1: Explicit api_region override
|
||||||
|
if r, ok := metadata["api_region"].(string); ok && r != "" {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 2: Extract from ProfileARN
|
||||||
|
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||||
|
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
|
||||||
|
return region
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return DefaultKiroRegion
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildURL(endpoint, path string, queryParams map[string]string) string {
|
||||||
|
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
|
||||||
|
if len(queryParams) > 0 {
|
||||||
|
values := url.Values{}
|
||||||
|
for key, value := range queryParams {
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
values.Set(key, value)
|
||||||
|
}
|
||||||
|
if encoded := values.Encode(); encoded != "" {
|
||||||
|
fullURL = fullURL + "?" + encoded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fullURL
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,15 +19,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
|
pathGetUsageLimits = "getUsageLimits"
|
||||||
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
|
pathListAvailableModels = "ListAvailableModels"
|
||||||
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
|
|
||||||
// for their respective API operations.
|
|
||||||
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
|
|
||||||
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
|
|
||||||
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
|
|
||||||
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
|
|
||||||
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
|
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
|
||||||
@@ -35,7 +28,6 @@ const (
|
|||||||
// and communicating with the CodeWhisperer API.
|
// and communicating with the CodeWhisperer API.
|
||||||
type KiroAuth struct {
|
type KiroAuth struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
endpoint string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKiroAuth creates a new Kiro authentication service.
|
// NewKiroAuth creates a new Kiro authentication service.
|
||||||
@@ -49,7 +41,6 @@ type KiroAuth struct {
|
|||||||
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
||||||
return &KiroAuth{
|
return &KiroAuth{
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
||||||
endpoint: awsKiroEndpoint,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,33 +101,30 @@ func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
|
|||||||
return time.Now().After(expiresAt)
|
return time.Now().After(expiresAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeRequest sends a request to the CodeWhisperer API.
|
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
|
||||||
// This is an internal method for making authenticated API calls.
|
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - ctx: The context for the request
|
// - ctx: The context for the request
|
||||||
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
|
// - path: The API path (e.g., "getUsageLimits")
|
||||||
// - accessToken: The OAuth access token
|
// - tokenData: The token data containing access token, refresh token, and profile ARN
|
||||||
// - payload: The request payload
|
// - queryParams: Query parameters to add to the URL
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - []byte: The response body
|
// - []byte: The response body
|
||||||
// - error: An error if the request fails
|
// - error: An error if the request fails
|
||||||
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
|
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
|
||||||
jsonBody, err := json.Marshal(payload)
|
// Get endpoint from profileArn (defaults to us-east-1 if empty)
|
||||||
if err != nil {
|
profileArn := queryParams["profileArn"]
|
||||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||||
}
|
url := buildURL(endpoint, path, queryParams)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||||
req.Header.Set("x-amz-target", target)
|
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := k.httpClient.Do(req)
|
resp, err := k.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -171,13 +159,13 @@ func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken s
|
|||||||
// - *KiroUsageInfo: The usage information
|
// - *KiroUsageInfo: The usage information
|
||||||
// - error: An error if the request fails
|
// - error: An error if the request fails
|
||||||
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
|
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
|
||||||
payload := map[string]interface{}{
|
queryParams := map[string]string{
|
||||||
"origin": "AI_EDITOR",
|
"origin": "AI_EDITOR",
|
||||||
"profileArn": tokenData.ProfileArn,
|
"profileArn": tokenData.ProfileArn,
|
||||||
"resourceType": "AGENTIC_REQUEST",
|
"resourceType": "AGENTIC_REQUEST",
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
|
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -221,12 +209,12 @@ func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData)
|
|||||||
// - []*KiroModel: The list of available models
|
// - []*KiroModel: The list of available models
|
||||||
// - error: An error if the request fails
|
// - error: An error if the request fails
|
||||||
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
|
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
|
||||||
payload := map[string]interface{}{
|
queryParams := map[string]string{
|
||||||
"origin": "AI_EDITOR",
|
"origin": "AI_EDITOR",
|
||||||
"profileArn": tokenData.ProfileArn,
|
"profileArn": tokenData.ProfileArn,
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
|
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package kiro
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -217,7 +218,8 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tokenData *KiroTokenData
|
tokenData *KiroTokenData
|
||||||
expected string
|
exact string // exact match (for cases with email)
|
||||||
|
prefix string // prefix match (for cases without email, where sequence is appended)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "IDC with email",
|
name: "IDC with email",
|
||||||
@@ -226,7 +228,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "user@example.com",
|
Email: "user@example.com",
|
||||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||||
},
|
},
|
||||||
expected: "kiro-idc-user-example-com.json",
|
exact: "kiro-idc-user-example-com.json",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IDC without email but with startUrl",
|
name: "IDC without email but with startUrl",
|
||||||
@@ -235,7 +237,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "",
|
Email: "",
|
||||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||||
},
|
},
|
||||||
expected: "kiro-idc-d-1234567890.json",
|
prefix: "kiro-idc-d-1234567890-",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IDC with company name in startUrl",
|
name: "IDC with company name in startUrl",
|
||||||
@@ -244,7 +246,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "",
|
Email: "",
|
||||||
StartURL: "https://my-company.awsapps.com/start",
|
StartURL: "https://my-company.awsapps.com/start",
|
||||||
},
|
},
|
||||||
expected: "kiro-idc-my-company.json",
|
prefix: "kiro-idc-my-company-",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IDC without email and without startUrl",
|
name: "IDC without email and without startUrl",
|
||||||
@@ -253,7 +255,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "",
|
Email: "",
|
||||||
StartURL: "",
|
StartURL: "",
|
||||||
},
|
},
|
||||||
expected: "kiro-idc.json",
|
prefix: "kiro-idc-",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Builder ID with email",
|
name: "Builder ID with email",
|
||||||
@@ -262,7 +264,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "user@gmail.com",
|
Email: "user@gmail.com",
|
||||||
StartURL: "https://view.awsapps.com/start",
|
StartURL: "https://view.awsapps.com/start",
|
||||||
},
|
},
|
||||||
expected: "kiro-builder-id-user-gmail-com.json",
|
exact: "kiro-builder-id-user-gmail-com.json",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Builder ID without email",
|
name: "Builder ID without email",
|
||||||
@@ -271,7 +273,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "",
|
Email: "",
|
||||||
StartURL: "https://view.awsapps.com/start",
|
StartURL: "https://view.awsapps.com/start",
|
||||||
},
|
},
|
||||||
expected: "kiro-builder-id.json",
|
prefix: "kiro-builder-id-",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Social auth with email",
|
name: "Social auth with email",
|
||||||
@@ -279,7 +281,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
AuthMethod: "google",
|
AuthMethod: "google",
|
||||||
Email: "user@gmail.com",
|
Email: "user@gmail.com",
|
||||||
},
|
},
|
||||||
expected: "kiro-google-user-gmail-com.json",
|
exact: "kiro-google-user-gmail-com.json",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty auth method",
|
name: "Empty auth method",
|
||||||
@@ -287,7 +289,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
AuthMethod: "",
|
AuthMethod: "",
|
||||||
Email: "",
|
Email: "",
|
||||||
},
|
},
|
||||||
expected: "kiro-unknown.json",
|
prefix: "kiro-unknown-",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Email with special characters",
|
name: "Email with special characters",
|
||||||
@@ -296,16 +298,454 @@ func TestGenerateTokenFileName(t *testing.T) {
|
|||||||
Email: "user.name+tag@sub.example.com",
|
Email: "user.name+tag@sub.example.com",
|
||||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||||
},
|
},
|
||||||
expected: "kiro-idc-user-name+tag-sub-example-com.json",
|
exact: "kiro-idc-user-name+tag-sub-example-com.json",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := GenerateTokenFileName(tt.tokenData)
|
result := GenerateTokenFileName(tt.tokenData)
|
||||||
if result != tt.expected {
|
if tt.exact != "" {
|
||||||
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected)
|
if result != tt.exact {
|
||||||
|
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
|
||||||
|
}
|
||||||
|
} else if tt.prefix != "" {
|
||||||
|
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
|
||||||
|
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseProfileARN(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
arn string
|
||||||
|
expected *ProfileARN
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty ARN",
|
||||||
|
arn: "",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid format - too few parts",
|
||||||
|
arn: "arn:aws:codewhisperer",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid prefix - not arn",
|
||||||
|
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid service - not codewhisperer",
|
||||||
|
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid region - no hyphen",
|
||||||
|
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty partition",
|
||||||
|
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty region",
|
||||||
|
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - us-east-1",
|
||||||
|
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||||
|
expected: &ProfileARN{
|
||||||
|
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||||
|
Partition: "aws",
|
||||||
|
Service: "codewhisperer",
|
||||||
|
Region: "us-east-1",
|
||||||
|
AccountID: "123456789012",
|
||||||
|
ResourceType: "profile",
|
||||||
|
ResourceID: "ABCDEFGHIJKL",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - ap-southeast-1",
|
||||||
|
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||||
|
expected: &ProfileARN{
|
||||||
|
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||||
|
Partition: "aws",
|
||||||
|
Service: "codewhisperer",
|
||||||
|
Region: "ap-southeast-1",
|
||||||
|
AccountID: "987654321098",
|
||||||
|
ResourceType: "profile",
|
||||||
|
ResourceID: "ZYXWVUTSRQ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - eu-west-1",
|
||||||
|
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||||
|
expected: &ProfileARN{
|
||||||
|
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||||
|
Partition: "aws",
|
||||||
|
Service: "codewhisperer",
|
||||||
|
Region: "eu-west-1",
|
||||||
|
AccountID: "111222333444",
|
||||||
|
ResourceType: "profile",
|
||||||
|
ResourceID: "PROFILE123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - aws-cn partition",
|
||||||
|
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||||
|
expected: &ProfileARN{
|
||||||
|
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||||
|
Partition: "aws-cn",
|
||||||
|
Service: "codewhisperer",
|
||||||
|
Region: "cn-north-1",
|
||||||
|
AccountID: "123456789012",
|
||||||
|
ResourceType: "profile",
|
||||||
|
ResourceID: "CHINAID",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - resource without slash",
|
||||||
|
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||||
|
expected: &ProfileARN{
|
||||||
|
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||||
|
Partition: "aws",
|
||||||
|
Service: "codewhisperer",
|
||||||
|
Region: "us-west-2",
|
||||||
|
AccountID: "123456789012",
|
||||||
|
ResourceType: "profile",
|
||||||
|
ResourceID: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - resource with colon",
|
||||||
|
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||||
|
expected: &ProfileARN{
|
||||||
|
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||||
|
Partition: "aws",
|
||||||
|
Service: "codewhisperer",
|
||||||
|
Region: "us-east-1",
|
||||||
|
AccountID: "123456789012",
|
||||||
|
ResourceType: "profile",
|
||||||
|
ResourceID: "ABC:extra",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ParseProfileARN(tt.arn)
|
||||||
|
if tt.expected == nil {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result.Raw != tt.expected.Raw {
|
||||||
|
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
|
||||||
|
}
|
||||||
|
if result.Partition != tt.expected.Partition {
|
||||||
|
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
|
||||||
|
}
|
||||||
|
if result.Service != tt.expected.Service {
|
||||||
|
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
|
||||||
|
}
|
||||||
|
if result.Region != tt.expected.Region {
|
||||||
|
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
|
||||||
|
}
|
||||||
|
if result.AccountID != tt.expected.AccountID {
|
||||||
|
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
|
||||||
|
}
|
||||||
|
if result.ResourceType != tt.expected.ResourceType {
|
||||||
|
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
|
||||||
|
}
|
||||||
|
if result.ResourceID != tt.expected.ResourceID {
|
||||||
|
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractRegionFromProfileArn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
profileArn string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty ARN",
|
||||||
|
profileArn: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid ARN",
|
||||||
|
profileArn: "invalid-arn",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - us-east-1",
|
||||||
|
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - ap-southeast-1",
|
||||||
|
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||||
|
expected: "ap-southeast-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - eu-central-1",
|
||||||
|
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||||
|
expected: "eu-central-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-codewhisperer ARN",
|
||||||
|
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ExtractRegionFromProfileArn(tt.profileArn)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroAPIEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
region string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty region - defaults to us-east-1",
|
||||||
|
region: "",
|
||||||
|
expected: "https://q.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "us-east-1",
|
||||||
|
region: "us-east-1",
|
||||||
|
expected: "https://q.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "us-west-2",
|
||||||
|
region: "us-west-2",
|
||||||
|
expected: "https://q.us-west-2.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ap-southeast-1",
|
||||||
|
region: "ap-southeast-1",
|
||||||
|
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "eu-west-1",
|
||||||
|
region: "eu-west-1",
|
||||||
|
expected: "https://q.eu-west-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cn-north-1",
|
||||||
|
region: "cn-north-1",
|
||||||
|
expected: "https://q.cn-north-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetKiroAPIEndpoint(tt.region)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
profileArn string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty ARN - defaults to us-east-1",
|
||||||
|
profileArn: "",
|
||||||
|
expected: "https://q.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid ARN - defaults to us-east-1",
|
||||||
|
profileArn: "invalid-arn",
|
||||||
|
expected: "https://q.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - us-east-1",
|
||||||
|
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||||
|
expected: "https://q.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - ap-southeast-1",
|
||||||
|
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||||
|
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid ARN - eu-central-1",
|
||||||
|
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||||
|
expected: "https://q.eu-central-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
region string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty region - defaults to us-east-1",
|
||||||
|
region: "",
|
||||||
|
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "us-east-1",
|
||||||
|
region: "us-east-1",
|
||||||
|
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "us-west-2",
|
||||||
|
region: "us-west-2",
|
||||||
|
expected: "https://codewhisperer.us-west-2.amazonaws.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ap-northeast-1",
|
||||||
|
region: "ap-northeast-1",
|
||||||
|
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetCodeWhispererLegacyEndpoint(tt.region)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractRegionFromMetadata(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
metadata map[string]interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Nil metadata - defaults to us-east-1",
|
||||||
|
metadata: nil,
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty metadata - defaults to us-east-1",
|
||||||
|
metadata: map[string]interface{}{},
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Priority 1: api_region override",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"api_region": "eu-west-1",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||||
|
},
|
||||||
|
expected: "eu-west-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Priority 2: profile_arn when api_region is empty",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"api_region": "",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||||
|
},
|
||||||
|
expected: "ap-southeast-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Priority 2: profile_arn when api_region is missing",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||||
|
},
|
||||||
|
expected: "eu-central-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Fallback: default when profile_arn is invalid",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"profile_arn": "invalid-arn",
|
||||||
|
},
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Fallback: default when profile_arn is empty",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"profile_arn": "",
|
||||||
|
},
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OIDC region is NOT used for API region",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||||
|
},
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "api_region takes precedence over OIDC region",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"api_region": "us-west-2",
|
||||||
|
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||||
|
},
|
||||||
|
expected: "us-west-2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-string api_region is ignored",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"api_region": 123, // wrong type
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
|
||||||
|
},
|
||||||
|
expected: "ap-south-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-string profile_arn is ignored",
|
||||||
|
metadata: map[string]interface{}{
|
||||||
|
"profile_arn": 123, // wrong type
|
||||||
|
},
|
||||||
|
expected: "us-east-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ExtractRegionFromMetadata(tt.metadata)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,30 +9,23 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"
|
|
||||||
kiroVersion = "0.6.18"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CodeWhispererClient handles CodeWhisperer API calls.
|
// CodeWhispererClient handles CodeWhisperer API calls.
|
||||||
type CodeWhispererClient struct {
|
type CodeWhispererClient struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
machineID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageLimitsResponse represents the getUsageLimits API response.
|
// UsageLimitsResponse represents the getUsageLimits API response.
|
||||||
type UsageLimitsResponse struct {
|
type UsageLimitsResponse struct {
|
||||||
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
||||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||||
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
||||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||||
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserInfo contains user information from the API.
|
// UserInfo contains user information from the API.
|
||||||
@@ -49,13 +42,13 @@ type SubscriptionInfo struct {
|
|||||||
|
|
||||||
// UsageBreakdown contains usage details.
|
// UsageBreakdown contains usage details.
|
||||||
type UsageBreakdown struct {
|
type UsageBreakdown struct {
|
||||||
UsageLimit *int `json:"usageLimit,omitempty"`
|
UsageLimit *int `json:"usageLimit,omitempty"`
|
||||||
CurrentUsage *int `json:"currentUsage,omitempty"`
|
CurrentUsage *int `json:"currentUsage,omitempty"`
|
||||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
||||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
||||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||||
DisplayName string `json:"displayName,omitempty"`
|
DisplayName string `json:"displayName,omitempty"`
|
||||||
ResourceType string `json:"resourceType,omitempty"`
|
ResourceType string `json:"resourceType,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
||||||
@@ -64,40 +57,34 @@ func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhisperer
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
}
|
}
|
||||||
if machineID == "" {
|
|
||||||
machineID = uuid.New().String()
|
|
||||||
}
|
|
||||||
return &CodeWhispererClient{
|
return &CodeWhispererClient{
|
||||||
httpClient: client,
|
httpClient: client,
|
||||||
machineID: machineID,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateInvocationID generates a unique invocation ID.
|
|
||||||
func generateInvocationID() string {
|
|
||||||
return uuid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
||||||
// This is the recommended way to get user email after login.
|
// This is the recommended way to get user email after login.
|
||||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
|
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
|
||||||
url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
|
queryParams := map[string]string{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
"resourceType": "AGENTIC_REQUEST",
|
||||||
|
}
|
||||||
|
// Determine endpoint based on profileArn region
|
||||||
|
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||||
|
if profileArn != "" {
|
||||||
|
queryParams["profileArn"] = profileArn
|
||||||
|
} else {
|
||||||
|
queryParams["isEmailRequired"] = "true"
|
||||||
|
}
|
||||||
|
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers to match Kiro IDE
|
accountKey := GetAccountKey(clientID, refreshToken)
|
||||||
xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
|
setRuntimeHeaders(req, accessToken, accountKey)
|
||||||
userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
|
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("x-amz-user-agent", xAmzUserAgent)
|
|
||||||
req.Header.Set("User-Agent", userAgent)
|
|
||||||
req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
|
|
||||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
|
||||||
req.Header.Set("Connection", "close")
|
|
||||||
|
|
||||||
log.Debugf("codewhisperer: GET %s", url)
|
log.Debugf("codewhisperer: GET %s", url)
|
||||||
|
|
||||||
@@ -128,8 +115,8 @@ func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken st
|
|||||||
|
|
||||||
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
||||||
// This is more reliable than JWT parsing as it uses the official API.
|
// This is more reliable than JWT parsing as it uses the official API.
|
||||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
|
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||||
resp, err := c.GetUsageLimits(ctx, accessToken)
|
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
||||||
return ""
|
return ""
|
||||||
@@ -146,10 +133,10 @@ func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessT
|
|||||||
|
|
||||||
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
||||||
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
||||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
|
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
|
||||||
// Method 1: Try CodeWhisperer API (most reliable)
|
// Method 1: Try CodeWhisperer API (most reliable)
|
||||||
cwClient := NewCodeWhispererClient(cfg, "")
|
cwClient := NewCodeWhispererClient(cfg, "")
|
||||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
|
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
|
||||||
if email != "" {
|
if email != "" {
|
||||||
return email
|
return email
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,77 +2,105 @@ package kiro
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Fingerprint 多维度指纹信息
|
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
|
||||||
type Fingerprint struct {
|
type Fingerprint struct {
|
||||||
SDKVersion string // 1.0.20-1.0.27
|
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
|
||||||
|
RuntimeSDKVersion string // 1.0.x (runtime API)
|
||||||
|
StreamingSDKVersion string // 1.0.x (streaming API)
|
||||||
OSType string // darwin/windows/linux
|
OSType string // darwin/windows/linux
|
||||||
OSVersion string // 10.0.22621
|
OSVersion string
|
||||||
NodeVersion string // 18.x/20.x/22.x
|
NodeVersion string
|
||||||
KiroVersion string // 0.3.x-0.8.x
|
KiroVersion string
|
||||||
KiroHash string // SHA256
|
KiroHash string // SHA256
|
||||||
AcceptLanguage string
|
|
||||||
ScreenResolution string // 1920x1080
|
|
||||||
ColorDepth int // 24
|
|
||||||
HardwareConcurrency int // CPU 核心数
|
|
||||||
TimezoneOffset int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FingerprintManager 指纹管理器
|
// FingerprintConfig holds external fingerprint overrides.
|
||||||
|
type FingerprintConfig struct {
|
||||||
|
OIDCSDKVersion string
|
||||||
|
RuntimeSDKVersion string
|
||||||
|
StreamingSDKVersion string
|
||||||
|
OSType string
|
||||||
|
OSVersion string
|
||||||
|
NodeVersion string
|
||||||
|
KiroVersion string
|
||||||
|
KiroHash string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FingerprintManager manages per-account fingerprint generation and caching.
|
||||||
type FingerprintManager struct {
|
type FingerprintManager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
|
config *FingerprintConfig // External config (Optional)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sdkVersions = []string{
|
// SDK versions
|
||||||
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
|
oidcSDKVersions = []string{
|
||||||
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
|
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
|
||||||
|
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
|
||||||
}
|
}
|
||||||
|
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
|
||||||
|
runtimeSDKVersions = []string{"1.0.0"}
|
||||||
|
// SDKVersions for generateAssistantResponse (streaming API)
|
||||||
|
streamingSDKVersions = []string{"1.0.27"}
|
||||||
|
// Valid OS types
|
||||||
osTypes = []string{"darwin", "windows", "linux"}
|
osTypes = []string{"darwin", "windows", "linux"}
|
||||||
|
// OS versions
|
||||||
osVersions = map[string][]string{
|
osVersions = map[string][]string{
|
||||||
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
|
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
|
||||||
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
|
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
|
||||||
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
|
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
|
||||||
}
|
}
|
||||||
|
// Node versions
|
||||||
nodeVersions = []string{
|
nodeVersions = []string{
|
||||||
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
|
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
|
||||||
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
|
"20.18.0", "20.17.0", "20.16.0",
|
||||||
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
|
|
||||||
}
|
}
|
||||||
|
// Kiro IDE versions
|
||||||
kiroVersions = []string{
|
kiroVersions = []string{
|
||||||
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
|
"0.10.32", "0.10.16", "0.10.10",
|
||||||
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
|
"0.9.47", "0.9.40", "0.9.2",
|
||||||
|
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
|
||||||
}
|
}
|
||||||
acceptLanguages = []string{
|
// Global singleton
|
||||||
"en-US,en;q=0.9",
|
globalFingerprintManager *FingerprintManager
|
||||||
"en-GB,en;q=0.9",
|
globalFingerprintManagerOnce sync.Once
|
||||||
"zh-CN,zh;q=0.9,en;q=0.8",
|
|
||||||
"zh-TW,zh;q=0.9,en;q=0.8",
|
|
||||||
"ja-JP,ja;q=0.9,en;q=0.8",
|
|
||||||
"ko-KR,ko;q=0.9,en;q=0.8",
|
|
||||||
"de-DE,de;q=0.9,en;q=0.8",
|
|
||||||
"fr-FR,fr;q=0.9,en;q=0.8",
|
|
||||||
}
|
|
||||||
screenResolutions = []string{
|
|
||||||
"1920x1080", "2560x1440", "3840x2160",
|
|
||||||
"1366x768", "1440x900", "1680x1050",
|
|
||||||
"2560x1600", "3440x1440",
|
|
||||||
}
|
|
||||||
colorDepths = []int{24, 32}
|
|
||||||
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
|
|
||||||
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFingerprintManager 创建指纹管理器
|
func GlobalFingerprintManager() *FingerprintManager {
|
||||||
|
globalFingerprintManagerOnce.Do(func() {
|
||||||
|
globalFingerprintManager = NewFingerprintManager()
|
||||||
|
})
|
||||||
|
return globalFingerprintManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
|
||||||
|
GlobalFingerprintManager().SetConfig(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConfig applies the config and clears the fingerprint cache.
|
||||||
|
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
|
||||||
|
fm.mu.Lock()
|
||||||
|
defer fm.mu.Unlock()
|
||||||
|
fm.config = cfg
|
||||||
|
// Clear cached fingerprints so they regenerate with the new config
|
||||||
|
fm.fingerprints = make(map[string]*Fingerprint)
|
||||||
|
}
|
||||||
|
|
||||||
func NewFingerprintManager() *FingerprintManager {
|
func NewFingerprintManager() *FingerprintManager {
|
||||||
return &FingerprintManager{
|
return &FingerprintManager{
|
||||||
fingerprints: make(map[string]*Fingerprint),
|
fingerprints: make(map[string]*Fingerprint),
|
||||||
@@ -80,7 +108,7 @@ func NewFingerprintManager() *FingerprintManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFingerprint 获取或生成 Token 关联的指纹
|
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
|
||||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||||
fm.mu.RLock()
|
fm.mu.RLock()
|
||||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||||
@@ -101,97 +129,150 @@ func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
|||||||
return fp
|
return fp
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateFingerprint 生成新的指纹
|
|
||||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||||
osType := fm.randomChoice(osTypes)
|
if fm.config != nil {
|
||||||
osVersion := fm.randomChoice(osVersions[osType])
|
return fm.generateFromConfig(tokenKey)
|
||||||
kiroVersion := fm.randomChoice(kiroVersions)
|
}
|
||||||
|
return fm.generateRandom(tokenKey)
|
||||||
|
}
|
||||||
|
|
||||||
fp := &Fingerprint{
|
// generateFromConfig uses config values, falling back to random for empty fields.
|
||||||
SDKVersion: fm.randomChoice(sdkVersions),
|
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
|
||||||
OSType: osType,
|
cfg := fm.config
|
||||||
OSVersion: osVersion,
|
|
||||||
NodeVersion: fm.randomChoice(nodeVersions),
|
// Helper: config value or random selection
|
||||||
KiroVersion: kiroVersion,
|
configOrRandom := func(configVal string, choices []string) string {
|
||||||
AcceptLanguage: fm.randomChoice(acceptLanguages),
|
if configVal != "" {
|
||||||
ScreenResolution: fm.randomChoice(screenResolutions),
|
return configVal
|
||||||
ColorDepth: fm.randomIntChoice(colorDepths),
|
}
|
||||||
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
|
return choices[fm.rng.Intn(len(choices))]
|
||||||
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
|
osType := cfg.OSType
|
||||||
return fp
|
if osType == "" {
|
||||||
|
osType = runtime.GOOS
|
||||||
|
if !slices.Contains(osTypes, osType) {
|
||||||
|
osType = osTypes[fm.rng.Intn(len(osTypes))]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
osVersion := cfg.OSVersion
|
||||||
|
if osVersion == "" {
|
||||||
|
if versions, ok := osVersions[osType]; ok {
|
||||||
|
osVersion = versions[fm.rng.Intn(len(versions))]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kiroHash := cfg.KiroHash
|
||||||
|
if kiroHash == "" {
|
||||||
|
hash := sha256.Sum256([]byte(tokenKey))
|
||||||
|
kiroHash = hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Fingerprint{
|
||||||
|
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
|
||||||
|
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
|
||||||
|
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
|
||||||
|
OSType: osType,
|
||||||
|
OSVersion: osVersion,
|
||||||
|
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
|
||||||
|
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
|
||||||
|
KiroHash: kiroHash,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateKiroHash 生成 Kiro Hash
|
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
|
||||||
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
|
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
|
||||||
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
|
// Use accountKey hash as seed for deterministic random selection
|
||||||
hash := sha256.Sum256([]byte(data))
|
hash := sha256.Sum256([]byte(accountKey))
|
||||||
return hex.EncodeToString(hash[:])
|
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||||
|
rng := rand.New(rand.NewSource(seed))
|
||||||
|
|
||||||
|
osType := runtime.GOOS
|
||||||
|
if !slices.Contains(osTypes, osType) {
|
||||||
|
osType = osTypes[rng.Intn(len(osTypes))]
|
||||||
|
}
|
||||||
|
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
|
||||||
|
|
||||||
|
return &Fingerprint{
|
||||||
|
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||||
|
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||||
|
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||||
|
OSType: osType,
|
||||||
|
OSVersion: osVersion,
|
||||||
|
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||||
|
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||||
|
KiroHash: hex.EncodeToString(hash[:]),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// randomChoice 随机选择字符串
|
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
|
||||||
func (fm *FingerprintManager) randomChoice(choices []string) string {
|
func GenerateAccountKey(seed string) string {
|
||||||
return choices[fm.rng.Intn(len(choices))]
|
hash := sha256.Sum256([]byte(seed))
|
||||||
|
return hex.EncodeToString(hash[:8])
|
||||||
}
|
}
|
||||||
|
|
||||||
// randomIntChoice 随机选择整数
|
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
|
||||||
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
|
func GetAccountKey(clientID, refreshToken string) string {
|
||||||
return choices[fm.rng.Intn(len(choices))]
|
// 1. Prefer ClientID
|
||||||
|
if clientID != "" {
|
||||||
|
return GenerateAccountKey(clientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Fallback to RefreshToken
|
||||||
|
if refreshToken != "" {
|
||||||
|
return GenerateAccountKey(refreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Random fallback
|
||||||
|
return GenerateAccountKey(uuid.New().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
|
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||||
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
|
|
||||||
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
|
|
||||||
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
|
|
||||||
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
|
|
||||||
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
|
|
||||||
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
|
|
||||||
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
|
|
||||||
req.Header.Set("Accept-Language", fp.AcceptLanguage)
|
|
||||||
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
|
|
||||||
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
|
|
||||||
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
|
|
||||||
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveFingerprint 移除 Token 关联的指纹
|
|
||||||
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
|
|
||||||
fm.mu.Lock()
|
|
||||||
defer fm.mu.Unlock()
|
|
||||||
delete(fm.fingerprints, tokenKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count 返回当前管理的指纹数量
|
|
||||||
func (fm *FingerprintManager) Count() int {
|
|
||||||
fm.mu.RLock()
|
|
||||||
defer fm.mu.RUnlock()
|
|
||||||
return len(fm.fingerprints)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
|
|
||||||
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
|
||||||
func (fp *Fingerprint) BuildUserAgent() string {
|
func (fp *Fingerprint) BuildUserAgent() string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||||
fp.SDKVersion,
|
fp.StreamingSDKVersion,
|
||||||
fp.OSType,
|
fp.OSType,
|
||||||
fp.OSVersion,
|
fp.OSVersion,
|
||||||
fp.NodeVersion,
|
fp.NodeVersion,
|
||||||
fp.SDKVersion,
|
fp.StreamingSDKVersion,
|
||||||
fp.KiroVersion,
|
fp.KiroVersion,
|
||||||
fp.KiroHash,
|
fp.KiroHash,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
|
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||||
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
|
||||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||||
fp.SDKVersion,
|
fp.StreamingSDKVersion,
|
||||||
fp.KiroVersion,
|
fp.KiroVersion,
|
||||||
fp.KiroHash,
|
fp.KiroHash,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetOIDCHeaders(req *http.Request) {
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||||
|
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
|
||||||
|
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
|
||||||
|
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||||
|
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||||
|
machineID := fp.KiroHash
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
|
||||||
|
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||||
|
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
|
||||||
|
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
|
||||||
|
fp.KiroVersion, machineID))
|
||||||
|
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||||
|
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package kiro
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -26,8 +28,14 @@ func TestGetFingerprint_NewToken(t *testing.T) {
|
|||||||
if fp == nil {
|
if fp == nil {
|
||||||
t.Fatal("expected non-nil Fingerprint")
|
t.Fatal("expected non-nil Fingerprint")
|
||||||
}
|
}
|
||||||
if fp.SDKVersion == "" {
|
if fp.OIDCSDKVersion == "" {
|
||||||
t.Error("expected non-empty SDKVersion")
|
t.Error("expected non-empty OIDCSDKVersion")
|
||||||
|
}
|
||||||
|
if fp.RuntimeSDKVersion == "" {
|
||||||
|
t.Error("expected non-empty RuntimeSDKVersion")
|
||||||
|
}
|
||||||
|
if fp.StreamingSDKVersion == "" {
|
||||||
|
t.Error("expected non-empty StreamingSDKVersion")
|
||||||
}
|
}
|
||||||
if fp.OSType == "" {
|
if fp.OSType == "" {
|
||||||
t.Error("expected non-empty OSType")
|
t.Error("expected non-empty OSType")
|
||||||
@@ -44,18 +52,6 @@ func TestGetFingerprint_NewToken(t *testing.T) {
|
|||||||
if fp.KiroHash == "" {
|
if fp.KiroHash == "" {
|
||||||
t.Error("expected non-empty KiroHash")
|
t.Error("expected non-empty KiroHash")
|
||||||
}
|
}
|
||||||
if fp.AcceptLanguage == "" {
|
|
||||||
t.Error("expected non-empty AcceptLanguage")
|
|
||||||
}
|
|
||||||
if fp.ScreenResolution == "" {
|
|
||||||
t.Error("expected non-empty ScreenResolution")
|
|
||||||
}
|
|
||||||
if fp.ColorDepth == 0 {
|
|
||||||
t.Error("expected non-zero ColorDepth")
|
|
||||||
}
|
|
||||||
if fp.HardwareConcurrency == 0 {
|
|
||||||
t.Error("expected non-zero HardwareConcurrency")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||||
@@ -78,72 +74,18 @@ func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveFingerprint(t *testing.T) {
|
func TestBuildUserAgent(t *testing.T) {
|
||||||
fm := NewFingerprintManager()
|
|
||||||
fm.GetFingerprint("token1")
|
|
||||||
if fm.Count() != 1 {
|
|
||||||
t.Fatalf("expected count 1, got %d", fm.Count())
|
|
||||||
}
|
|
||||||
|
|
||||||
fm.RemoveFingerprint("token1")
|
|
||||||
if fm.Count() != 0 {
|
|
||||||
t.Errorf("expected count 0, got %d", fm.Count())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRemoveFingerprint_NonExistent(t *testing.T) {
|
|
||||||
fm := NewFingerprintManager()
|
|
||||||
fm.RemoveFingerprint("nonexistent")
|
|
||||||
if fm.Count() != 0 {
|
|
||||||
t.Errorf("expected count 0, got %d", fm.Count())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCount(t *testing.T) {
|
|
||||||
fm := NewFingerprintManager()
|
|
||||||
if fm.Count() != 0 {
|
|
||||||
t.Errorf("expected count 0, got %d", fm.Count())
|
|
||||||
}
|
|
||||||
|
|
||||||
fm.GetFingerprint("token1")
|
|
||||||
fm.GetFingerprint("token2")
|
|
||||||
fm.GetFingerprint("token3")
|
|
||||||
|
|
||||||
if fm.Count() != 3 {
|
|
||||||
t.Errorf("expected count 3, got %d", fm.Count())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestApplyToRequest(t *testing.T) {
|
|
||||||
fm := NewFingerprintManager()
|
fm := NewFingerprintManager()
|
||||||
fp := fm.GetFingerprint("token1")
|
fp := fm.GetFingerprint("token1")
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
ua := fp.BuildUserAgent()
|
||||||
fp.ApplyToRequest(req)
|
if ua == "" {
|
||||||
|
t.Error("expected non-empty User-Agent")
|
||||||
|
}
|
||||||
|
|
||||||
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
|
amzUA := fp.BuildAmzUserAgent()
|
||||||
t.Error("X-Kiro-SDK-Version header mismatch")
|
if amzUA == "" {
|
||||||
}
|
t.Error("expected non-empty X-Amz-User-Agent")
|
||||||
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
|
|
||||||
t.Error("X-Kiro-OS-Type header mismatch")
|
|
||||||
}
|
|
||||||
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
|
|
||||||
t.Error("X-Kiro-OS-Version header mismatch")
|
|
||||||
}
|
|
||||||
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
|
|
||||||
t.Error("X-Kiro-Node-Version header mismatch")
|
|
||||||
}
|
|
||||||
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
|
|
||||||
t.Error("X-Kiro-Version header mismatch")
|
|
||||||
}
|
|
||||||
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
|
|
||||||
t.Error("X-Kiro-Hash header mismatch")
|
|
||||||
}
|
|
||||||
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
|
|
||||||
t.Error("Accept-Language header mismatch")
|
|
||||||
}
|
|
||||||
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
|
|
||||||
t.Error("X-Screen-Resolution header mismatch")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +108,33 @@ func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
|
||||||
|
// Set config with empty OSType to trigger runtime.GOOS fallback
|
||||||
|
fm.SetConfig(&FingerprintConfig{
|
||||||
|
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
|
||||||
|
})
|
||||||
|
|
||||||
|
fp := fm.GetFingerprint("test-token")
|
||||||
|
|
||||||
|
// Expected OS type based on runtime.GOOS mapping
|
||||||
|
var expectedOS string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
expectedOS = "darwin"
|
||||||
|
case "windows":
|
||||||
|
expectedOS = "windows"
|
||||||
|
default:
|
||||||
|
expectedOS = "linux"
|
||||||
|
}
|
||||||
|
|
||||||
|
if fp.OSType != expectedOS {
|
||||||
|
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
|
||||||
|
expectedOS, runtime.GOOS, fp.OSType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||||
fm := NewFingerprintManager()
|
fm := NewFingerprintManager()
|
||||||
const numGoroutines = 100
|
const numGoroutines = 100
|
||||||
@@ -174,22 +143,18 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(numGoroutines)
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for j := 0; j < numOperations; j++ {
|
for j := range numOperations {
|
||||||
tokenKey := "token" + string(rune('a'+id%26))
|
tokenKey := "token" + string(rune('a'+id%26))
|
||||||
switch j % 4 {
|
switch j % 2 {
|
||||||
case 0:
|
case 0:
|
||||||
fm.GetFingerprint(tokenKey)
|
fm.GetFingerprint(tokenKey)
|
||||||
case 1:
|
case 1:
|
||||||
fm.Count()
|
|
||||||
case 2:
|
|
||||||
fp := fm.GetFingerprint(tokenKey)
|
fp := fm.GetFingerprint(tokenKey)
|
||||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
_ = fp.BuildUserAgent()
|
||||||
fp.ApplyToRequest(req)
|
_ = fp.BuildAmzUserAgent()
|
||||||
case 3:
|
|
||||||
fm.RemoveFingerprint(tokenKey)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
@@ -198,16 +163,20 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKiroHashUniqueness(t *testing.T) {
|
func TestKiroHashStability(t *testing.T) {
|
||||||
fm := NewFingerprintManager()
|
fm := NewFingerprintManager()
|
||||||
hashes := make(map[string]bool)
|
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
// Same token should always return same hash
|
||||||
fp := fm.GetFingerprint("token" + string(rune(i)))
|
fp1 := fm.GetFingerprint("token1")
|
||||||
if hashes[fp.KiroHash] {
|
fp2 := fm.GetFingerprint("token1")
|
||||||
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
|
if fp1.KiroHash != fp2.KiroHash {
|
||||||
}
|
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
|
||||||
hashes[fp.KiroHash] = true
|
}
|
||||||
|
|
||||||
|
// Different tokens should have different hashes
|
||||||
|
fp3 := fm.GetFingerprint("token2")
|
||||||
|
if fp1.KiroHash == fp3.KiroHash {
|
||||||
|
t.Errorf("different tokens should have different hashes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,8 +189,590 @@ func TestKiroHashFormat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range fp.KiroHash {
|
for _, c := range fp.KiroHash {
|
||||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGlobalFingerprintManager(t *testing.T) {
|
||||||
|
fm1 := GlobalFingerprintManager()
|
||||||
|
fm2 := GlobalFingerprintManager()
|
||||||
|
|
||||||
|
if fm1 == nil {
|
||||||
|
t.Fatal("expected non-nil GlobalFingerprintManager")
|
||||||
|
}
|
||||||
|
if fm1 != fm2 {
|
||||||
|
t.Error("expected GlobalFingerprintManager to return same instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetOIDCHeaders(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
SetOIDCHeaders(req)
|
||||||
|
|
||||||
|
if req.Header.Get("Content-Type") != "application/json" {
|
||||||
|
t.Error("expected Content-Type header to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
amzUA := req.Header.Get("x-amz-user-agent")
|
||||||
|
if amzUA == "" {
|
||||||
|
t.Error("expected x-amz-user-agent header to be set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||||
|
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||||
|
}
|
||||||
|
if !strings.Contains(amzUA, "KiroIDE") {
|
||||||
|
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||||
|
}
|
||||||
|
|
||||||
|
ua := req.Header.Get("User-Agent")
|
||||||
|
if ua == "" {
|
||||||
|
t.Error("expected User-Agent header to be set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(ua, "api/sso-oidc") {
|
||||||
|
t.Errorf("User-Agent should contain api name: %s", ua)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Header.Get("amz-sdk-invocation-id") == "" {
|
||||||
|
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||||
|
}
|
||||||
|
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
|
||||||
|
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
endpoint string
|
||||||
|
path string
|
||||||
|
queryParams map[string]string
|
||||||
|
want string
|
||||||
|
wantContains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no query params",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
path: "getUsageLimits",
|
||||||
|
queryParams: nil,
|
||||||
|
want: "https://api.example.com/getUsageLimits",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty query params",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
path: "getUsageLimits",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
want: "https://api.example.com/getUsageLimits",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single query param",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
path: "getUsageLimits",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
},
|
||||||
|
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple query params",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
path: "getUsageLimits",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
"resourceType": "AGENTIC_REQUEST",
|
||||||
|
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
|
||||||
|
},
|
||||||
|
wantContains: []string{
|
||||||
|
"https://api.example.com/getUsageLimits?",
|
||||||
|
"origin=AI_EDITOR",
|
||||||
|
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
|
||||||
|
"resourceType=AGENTIC_REQUEST",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "omit empty params",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
path: "getUsageLimits",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
"profileArn": "",
|
||||||
|
},
|
||||||
|
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
|
||||||
|
if tt.want != "" {
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("buildURL() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tt.wantContains != nil {
|
||||||
|
for _, substr := range tt.wantContains {
|
||||||
|
if !strings.Contains(got, substr) {
|
||||||
|
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUserAgentFormat(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
fp := fm.GetFingerprint("token1")
|
||||||
|
|
||||||
|
ua := fp.BuildUserAgent()
|
||||||
|
requiredParts := []string{
|
||||||
|
"aws-sdk-js/",
|
||||||
|
"ua/2.1",
|
||||||
|
"os/",
|
||||||
|
"lang/js",
|
||||||
|
"md/nodejs#",
|
||||||
|
"api/codewhispererstreaming#",
|
||||||
|
"m/E",
|
||||||
|
"KiroIDE-",
|
||||||
|
}
|
||||||
|
for _, part := range requiredParts {
|
||||||
|
if !strings.Contains(ua, part) {
|
||||||
|
t.Errorf("User-Agent missing required part %q: %s", part, ua)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAmzUserAgentFormat(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
fp := fm.GetFingerprint("token1")
|
||||||
|
|
||||||
|
amzUA := fp.BuildAmzUserAgent()
|
||||||
|
requiredParts := []string{
|
||||||
|
"aws-sdk-js/",
|
||||||
|
"KiroIDE-",
|
||||||
|
}
|
||||||
|
for _, part := range requiredParts {
|
||||||
|
if !strings.Contains(amzUA, part) {
|
||||||
|
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Amz-User-Agent should be shorter than User-Agent
|
||||||
|
ua := fp.BuildUserAgent()
|
||||||
|
if len(amzUA) >= len(ua) {
|
||||||
|
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRuntimeHeaders(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
accessToken := "test-access-token-1234567890"
|
||||||
|
clientID := "test-client-id-12345"
|
||||||
|
accountKey := GenerateAccountKey(clientID)
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||||
|
machineID := fp.KiroHash
|
||||||
|
|
||||||
|
setRuntimeHeaders(req, accessToken, accountKey)
|
||||||
|
|
||||||
|
// Check Authorization header
|
||||||
|
if req.Header.Get("Authorization") != "Bearer "+accessToken {
|
||||||
|
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check x-amz-user-agent header
|
||||||
|
amzUA := req.Header.Get("x-amz-user-agent")
|
||||||
|
if amzUA == "" {
|
||||||
|
t.Error("expected x-amz-user-agent header to be set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||||
|
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||||
|
}
|
||||||
|
if !strings.Contains(amzUA, "KiroIDE-") {
|
||||||
|
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||||
|
}
|
||||||
|
if !strings.Contains(amzUA, machineID) {
|
||||||
|
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check User-Agent header
|
||||||
|
ua := req.Header.Get("User-Agent")
|
||||||
|
if ua == "" {
|
||||||
|
t.Error("expected User-Agent header to be set")
|
||||||
|
}
|
||||||
|
if !strings.Contains(ua, "api/codewhispererruntime#") {
|
||||||
|
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
|
||||||
|
}
|
||||||
|
if !strings.Contains(ua, "m/N,E") {
|
||||||
|
t.Errorf("User-Agent should contain m/N,E: %s", ua)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check amz-sdk-invocation-id (should be a UUID)
|
||||||
|
invocationID := req.Header.Get("amz-sdk-invocation-id")
|
||||||
|
if invocationID == "" {
|
||||||
|
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||||
|
}
|
||||||
|
if len(invocationID) != 36 {
|
||||||
|
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check amz-sdk-request
|
||||||
|
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
|
||||||
|
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSDKVersionsAreValid(t *testing.T) {
|
||||||
|
// Verify all OIDC SDK versions match expected format (3.xxx.x)
|
||||||
|
for _, v := range oidcSDKVersions {
|
||||||
|
if !strings.HasPrefix(v, "3.") {
|
||||||
|
t.Errorf("OIDC SDK version should start with 3.: %s", v)
|
||||||
|
}
|
||||||
|
parts := strings.Split(v, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range runtimeSDKVersions {
|
||||||
|
parts := strings.Split(v, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range streamingSDKVersions {
|
||||||
|
parts := strings.Split(v, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKiroVersionsAreValid(t *testing.T) {
|
||||||
|
// Verify all Kiro versions match expected format (0.x.xxx)
|
||||||
|
for _, v := range kiroVersions {
|
||||||
|
if !strings.HasPrefix(v, "0.") {
|
||||||
|
t.Errorf("Kiro version should start with 0.: %s", v)
|
||||||
|
}
|
||||||
|
parts := strings.Split(v, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Errorf("Kiro version should have 3 parts: %s", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeVersionsAreValid(t *testing.T) {
|
||||||
|
// Verify all Node versions match expected format (xx.xx.x)
|
||||||
|
for _, v := range nodeVersions {
|
||||||
|
parts := strings.Split(v, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Errorf("Node version should have 3 parts: %s", v)
|
||||||
|
}
|
||||||
|
// Should be Node 20.x or 22.x
|
||||||
|
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
|
||||||
|
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFingerprintManager_SetConfig(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
|
||||||
|
// Without config, should generate random fingerprint
|
||||||
|
fp1 := fm.GetFingerprint("token1")
|
||||||
|
if fp1 == nil {
|
||||||
|
t.Fatal("expected non-nil fingerprint")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set config with all fields
|
||||||
|
cfg := &FingerprintConfig{
|
||||||
|
OIDCSDKVersion: "3.999.0",
|
||||||
|
RuntimeSDKVersion: "9.9.9",
|
||||||
|
StreamingSDKVersion: "8.8.8",
|
||||||
|
OSType: "darwin",
|
||||||
|
OSVersion: "99.0.0",
|
||||||
|
NodeVersion: "99.99.99",
|
||||||
|
KiroVersion: "9.9.999",
|
||||||
|
KiroHash: "customhash123",
|
||||||
|
}
|
||||||
|
fm.SetConfig(cfg)
|
||||||
|
|
||||||
|
// After setting config, should use config values
|
||||||
|
fp2 := fm.GetFingerprint("token2")
|
||||||
|
if fp2.OIDCSDKVersion != "3.999.0" {
|
||||||
|
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
|
||||||
|
}
|
||||||
|
if fp2.RuntimeSDKVersion != "9.9.9" {
|
||||||
|
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
|
||||||
|
}
|
||||||
|
if fp2.StreamingSDKVersion != "8.8.8" {
|
||||||
|
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
|
||||||
|
}
|
||||||
|
if fp2.OSType != "darwin" {
|
||||||
|
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
|
||||||
|
}
|
||||||
|
if fp2.OSVersion != "99.0.0" {
|
||||||
|
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
|
||||||
|
}
|
||||||
|
if fp2.NodeVersion != "99.99.99" {
|
||||||
|
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
|
||||||
|
}
|
||||||
|
if fp2.KiroVersion != "9.9.999" {
|
||||||
|
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
|
||||||
|
}
|
||||||
|
if fp2.KiroHash != "customhash123" {
|
||||||
|
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
|
||||||
|
// Set config with only some fields
|
||||||
|
cfg := &FingerprintConfig{
|
||||||
|
KiroVersion: "1.2.345",
|
||||||
|
KiroHash: "myhash",
|
||||||
|
// Other fields empty - should use random
|
||||||
|
}
|
||||||
|
fm.SetConfig(cfg)
|
||||||
|
|
||||||
|
fp := fm.GetFingerprint("token1")
|
||||||
|
|
||||||
|
// Configured fields should use config values
|
||||||
|
if fp.KiroVersion != "1.2.345" {
|
||||||
|
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
|
||||||
|
}
|
||||||
|
if fp.KiroHash != "myhash" {
|
||||||
|
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty fields should be randomly selected (non-empty)
|
||||||
|
if fp.OIDCSDKVersion == "" {
|
||||||
|
t.Error("expected non-empty OIDCSDKVersion")
|
||||||
|
}
|
||||||
|
if fp.OSType == "" {
|
||||||
|
t.Error("expected non-empty OSType")
|
||||||
|
}
|
||||||
|
if fp.NodeVersion == "" {
|
||||||
|
t.Error("expected non-empty NodeVersion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
|
||||||
|
// Get fingerprint before config
|
||||||
|
fp1 := fm.GetFingerprint("token1")
|
||||||
|
originalHash := fp1.KiroHash
|
||||||
|
|
||||||
|
// Set config
|
||||||
|
cfg := &FingerprintConfig{
|
||||||
|
KiroHash: "newcustomhash",
|
||||||
|
}
|
||||||
|
fm.SetConfig(cfg)
|
||||||
|
|
||||||
|
// Same token should now return different fingerprint (cache cleared)
|
||||||
|
fp2 := fm.GetFingerprint("token1")
|
||||||
|
if fp2.KiroHash == originalHash {
|
||||||
|
t.Error("expected cache to be cleared after SetConfig")
|
||||||
|
}
|
||||||
|
if fp2.KiroHash != "newcustomhash" {
|
||||||
|
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAccountKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
seed string
|
||||||
|
check func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty seed",
|
||||||
|
seed: "",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
if result == "" {
|
||||||
|
t.Error("expected non-empty result for empty seed")
|
||||||
|
}
|
||||||
|
if len(result) != 16 {
|
||||||
|
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple seed",
|
||||||
|
seed: "test-client-id",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
if len(result) != 16 {
|
||||||
|
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||||
|
}
|
||||||
|
// Verify it's valid hex
|
||||||
|
for _, c := range result {
|
||||||
|
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||||
|
t.Errorf("invalid hex character: %c", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Same seed produces same result",
|
||||||
|
seed: "deterministic-seed",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
result2 := GenerateAccountKey("deterministic-seed")
|
||||||
|
if result != result2 {
|
||||||
|
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Different seeds produce different results",
|
||||||
|
seed: "seed-one",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
result2 := GenerateAccountKey("seed-two")
|
||||||
|
if result == result2 {
|
||||||
|
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GenerateAccountKey(tt.seed)
|
||||||
|
tt.check(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
clientID string
|
||||||
|
refreshToken string
|
||||||
|
check func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Priority 1: clientID when both provided",
|
||||||
|
clientID: "client-id-123",
|
||||||
|
refreshToken: "refresh-token-456",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
expected := GenerateAccountKey("client-id-123")
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Priority 2: refreshToken when clientID is empty",
|
||||||
|
clientID: "",
|
||||||
|
refreshToken: "refresh-token-789",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
expected := GenerateAccountKey("refresh-token-789")
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Priority 3: random when both empty",
|
||||||
|
clientID: "",
|
||||||
|
refreshToken: "",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
if len(result) != 16 {
|
||||||
|
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||||
|
}
|
||||||
|
// Should be different each time (random UUID)
|
||||||
|
result2 := GetAccountKey("", "")
|
||||||
|
if result == result2 {
|
||||||
|
t.Log("warning: random keys are the same (possible but unlikely)")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "clientID only",
|
||||||
|
clientID: "solo-client-id",
|
||||||
|
refreshToken: "",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
expected := GenerateAccountKey("solo-client-id")
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "refreshToken only",
|
||||||
|
clientID: "",
|
||||||
|
refreshToken: "solo-refresh-token",
|
||||||
|
check: func(t *testing.T, result string) {
|
||||||
|
expected := GenerateAccountKey("solo-refresh-token")
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetAccountKey(tt.clientID, tt.refreshToken)
|
||||||
|
tt.check(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountKey_Deterministic(t *testing.T) {
|
||||||
|
// Verify that GetAccountKey produces deterministic results for same inputs
|
||||||
|
clientID := "test-client-id-abc"
|
||||||
|
refreshToken := "test-refresh-token-xyz"
|
||||||
|
|
||||||
|
// Call multiple times with same inputs
|
||||||
|
results := make([]string, 10)
|
||||||
|
for i := range 10 {
|
||||||
|
results[i] = GetAccountKey(clientID, refreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All results should be identical
|
||||||
|
for i := 1; i < 10; i++ {
|
||||||
|
if results[i] != results[0] {
|
||||||
|
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFingerprintDeterministic(t *testing.T) {
|
||||||
|
// Verify that fingerprints are deterministic based on accountKey
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
|
||||||
|
accountKey := GenerateAccountKey("test-client-id")
|
||||||
|
|
||||||
|
// Get fingerprint multiple times
|
||||||
|
fp1 := fm.GetFingerprint(accountKey)
|
||||||
|
fp2 := fm.GetFingerprint(accountKey)
|
||||||
|
|
||||||
|
// Should be the same pointer (cached)
|
||||||
|
if fp1 != fp2 {
|
||||||
|
t.Error("expected same fingerprint pointer for same key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new manager and verify same values
|
||||||
|
fm2 := NewFingerprintManager()
|
||||||
|
fp3 := fm2.GetFingerprint(accountKey)
|
||||||
|
|
||||||
|
// Values should be identical (deterministic generation)
|
||||||
|
if fp1.KiroHash != fp3.KiroHash {
|
||||||
|
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
|
||||||
|
}
|
||||||
|
if fp1.OSType != fp3.OSType {
|
||||||
|
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
|
||||||
|
}
|
||||||
|
if fp1.OSVersion != fp3.OSVersion {
|
||||||
|
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
|
||||||
|
}
|
||||||
|
if fp1.KiroVersion != fp3.KiroVersion {
|
||||||
|
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
|
||||||
|
}
|
||||||
|
if fp1.NodeVersion != fp3.NodeVersion {
|
||||||
|
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,10 +23,10 @@ import (
|
|||||||
const (
|
const (
|
||||||
// Kiro auth endpoint
|
// Kiro auth endpoint
|
||||||
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||||
|
|
||||||
// Default callback port
|
// Default callback port
|
||||||
defaultCallbackPort = 9876
|
defaultCallbackPort = 9876
|
||||||
|
|
||||||
// Auth timeout
|
// Auth timeout
|
||||||
authTimeout = 10 * time.Minute
|
authTimeout = 10 * time.Minute
|
||||||
)
|
)
|
||||||
@@ -41,8 +41,10 @@ type KiroTokenResponse struct {
|
|||||||
|
|
||||||
// KiroOAuth handles the OAuth flow for Kiro authentication.
|
// KiroOAuth handles the OAuth flow for Kiro authentication.
|
||||||
type KiroOAuth struct {
|
type KiroOAuth struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
machineID string
|
||||||
|
kiroVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKiroOAuth creates a new Kiro OAuth handler.
|
// NewKiroOAuth creates a new Kiro OAuth handler.
|
||||||
@@ -51,9 +53,12 @@ func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
}
|
}
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||||
return &KiroOAuth{
|
return &KiroOAuth{
|
||||||
httpClient: client,
|
httpClient: client,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
machineID: fp.KiroHash,
|
||||||
|
kiroVersion: fp.KiroVersion,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +195,8 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
|
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
|
||||||
resp, err := o.httpClient.Do(req)
|
resp, err := o.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -256,11 +262,8 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||||
// Use KiroIDE-style User-Agent to match official Kiro IDE behavior
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
// This helps avoid 403 errors from server-side User-Agent validation
|
|
||||||
userAgent := buildKiroUserAgent(tokenKey)
|
|
||||||
req.Header.Set("User-Agent", userAgent)
|
|
||||||
|
|
||||||
resp, err := o.httpClient.Do(req)
|
resp, err := o.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -301,19 +304,6 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
|
|
||||||
// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
|
|
||||||
// Otherwise generates a simple KiroIDE User-Agent.
|
|
||||||
func buildKiroUserAgent(tokenKey string) string {
|
|
||||||
if tokenKey != "" {
|
|
||||||
fm := NewFingerprintManager()
|
|
||||||
fp := fm.GetFingerprint(tokenKey)
|
|
||||||
return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
|
|
||||||
}
|
|
||||||
// Default KiroIDE User-Agent matching kiro-openai-gateway format
|
|
||||||
return "KiroIDE-0.7.45-cli-proxy-api"
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
|||||||
@@ -35,35 +35,35 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type webAuthSession struct {
|
type webAuthSession struct {
|
||||||
stateID string
|
stateID string
|
||||||
deviceCode string
|
deviceCode string
|
||||||
userCode string
|
userCode string
|
||||||
authURL string
|
authURL string
|
||||||
verificationURI string
|
verificationURI string
|
||||||
expiresIn int
|
expiresIn int
|
||||||
interval int
|
interval int
|
||||||
status authSessionStatus
|
status authSessionStatus
|
||||||
startedAt time.Time
|
startedAt time.Time
|
||||||
completedAt time.Time
|
completedAt time.Time
|
||||||
expiresAt time.Time
|
expiresAt time.Time
|
||||||
error string
|
error string
|
||||||
tokenData *KiroTokenData
|
tokenData *KiroTokenData
|
||||||
ssoClient *SSOOIDCClient
|
ssoClient *SSOOIDCClient
|
||||||
clientID string
|
clientID string
|
||||||
clientSecret string
|
clientSecret string
|
||||||
region string
|
region string
|
||||||
cancelFunc context.CancelFunc
|
cancelFunc context.CancelFunc
|
||||||
authMethod string // "google", "github", "builder-id", "idc"
|
authMethod string // "google", "github", "builder-id", "idc"
|
||||||
startURL string // Used for IDC
|
startURL string // Used for IDC
|
||||||
codeVerifier string // Used for social auth PKCE
|
codeVerifier string // Used for social auth PKCE
|
||||||
codeChallenge string // Used for social auth PKCE
|
codeChallenge string // Used for social auth PKCE
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthWebHandler struct {
|
type OAuthWebHandler struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
sessions map[string]*webAuthSession
|
sessions map[string]*webAuthSession
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
onTokenObtained func(*KiroTokenData)
|
onTokenObtained func(*KiroTokenData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||||
@@ -104,7 +104,7 @@ func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||||
method := c.Query("method")
|
method := c.Query("method")
|
||||||
|
|
||||||
if method == "" {
|
if method == "" {
|
||||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||||
return
|
return
|
||||||
@@ -138,7 +138,7 @@ func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
socialClient := NewSocialAuthClient(h.cfg)
|
socialClient := NewSocialAuthClient(h.cfg)
|
||||||
|
|
||||||
var provider string
|
var provider string
|
||||||
if method == "google" {
|
if method == "google" {
|
||||||
provider = string(ProviderGoogle)
|
provider = string(ProviderGoogle)
|
||||||
@@ -373,22 +373,28 @@ func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSess
|
|||||||
}
|
}
|
||||||
|
|
||||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
|
||||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
// Fetch profileArn for IDC
|
||||||
|
var profileArn string
|
||||||
|
if session.authMethod == "idc" {
|
||||||
|
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||||
|
|
||||||
tokenData := &KiroTokenData{
|
tokenData := &KiroTokenData{
|
||||||
AccessToken: tokenResp.AccessToken,
|
AccessToken: tokenResp.AccessToken,
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
ProfileArn: profileArn,
|
ProfileArn: profileArn,
|
||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
AuthMethod: session.authMethod,
|
AuthMethod: session.authMethod,
|
||||||
Provider: "AWS",
|
Provider: "AWS",
|
||||||
ClientID: session.clientID,
|
ClientID: session.clientID,
|
||||||
ClientSecret: session.clientSecret,
|
ClientSecret: session.clientSecret,
|
||||||
Email: email,
|
Email: email,
|
||||||
Region: session.region,
|
Region: session.region,
|
||||||
StartURL: session.startURL,
|
StartURL: session.startURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
session.status = statusSuccess
|
session.status = statusSuccess
|
||||||
@@ -442,7 +448,7 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
|||||||
fileName := GenerateTokenFileName(tokenData)
|
fileName := GenerateTokenFileName(tokenData)
|
||||||
|
|
||||||
authFilePath := filepath.Join(authDir, fileName)
|
authFilePath := filepath.Join(authDir, fileName)
|
||||||
|
|
||||||
// Convert to storage format and save
|
// Convert to storage format and save
|
||||||
storage := &KiroTokenStorage{
|
storage := &KiroTokenStorage{
|
||||||
Type: "kiro",
|
Type: "kiro",
|
||||||
@@ -459,12 +465,12 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
|||||||
StartURL: tokenData.StartURL,
|
StartURL: tokenData.StartURL,
|
||||||
Email: tokenData.Email,
|
Email: tokenData.Email,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RefreshManager 是后台刷新器的单例管理器
|
// RefreshManager is a singleton manager for background token refreshing.
|
||||||
type RefreshManager struct {
|
type RefreshManager struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
refresher *BackgroundRefresher
|
refresher *BackgroundRefresher
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
started bool
|
started bool
|
||||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -25,7 +25,7 @@ var (
|
|||||||
managerOnce sync.Once
|
managerOnce sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetRefreshManager 获取全局刷新管理器实例
|
// GetRefreshManager returns the global RefreshManager singleton.
|
||||||
func GetRefreshManager() *RefreshManager {
|
func GetRefreshManager() *RefreshManager {
|
||||||
managerOnce.Do(func() {
|
managerOnce.Do(func() {
|
||||||
globalRefreshManager = &RefreshManager{}
|
globalRefreshManager = &RefreshManager{}
|
||||||
@@ -33,9 +33,7 @@ func GetRefreshManager() *RefreshManager {
|
|||||||
return globalRefreshManager
|
return globalRefreshManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize 初始化后台刷新器
|
// Initialize sets up the background refresher.
|
||||||
// baseDir: token 文件所在的目录
|
|
||||||
// cfg: 应用配置
|
|
||||||
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -58,18 +56,16 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
|||||||
baseDir = resolvedBaseDir
|
baseDir = resolvedBaseDir
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建 token 存储库
|
|
||||||
repo := NewFileTokenRepository(baseDir)
|
repo := NewFileTokenRepository(baseDir)
|
||||||
|
|
||||||
// 创建后台刷新器,配置参数
|
|
||||||
opts := []RefresherOption{
|
opts := []RefresherOption{
|
||||||
WithInterval(time.Minute), // 每分钟检查一次
|
WithInterval(time.Minute),
|
||||||
WithBatchSize(50), // 每批最多处理 50 个 token
|
WithBatchSize(50),
|
||||||
WithConcurrency(10), // 最多 10 个并发刷新
|
WithConcurrency(10),
|
||||||
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
|
WithConfig(cfg),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果已设置回调,传递给 BackgroundRefresher
|
// Pass callback to BackgroundRefresher if already set
|
||||||
if m.onTokenRefreshed != nil {
|
if m.onTokenRefreshed != nil {
|
||||||
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
|
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
|
||||||
}
|
}
|
||||||
@@ -80,7 +76,7 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start 启动后台刷新
|
// Start begins background token refreshing.
|
||||||
func (m *RefreshManager) Start() {
|
func (m *RefreshManager) Start() {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -102,7 +98,7 @@ func (m *RefreshManager) Start() {
|
|||||||
log.Info("refresh manager: background refresh started")
|
log.Info("refresh manager: background refresh started")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop 停止后台刷新
|
// Stop halts background token refreshing.
|
||||||
func (m *RefreshManager) Stop() {
|
func (m *RefreshManager) Stop() {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -123,14 +119,14 @@ func (m *RefreshManager) Stop() {
|
|||||||
log.Info("refresh manager: background refresh stopped")
|
log.Info("refresh manager: background refresh stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsRunning 检查后台刷新是否正在运行
|
// IsRunning reports whether background refreshing is active.
|
||||||
func (m *RefreshManager) IsRunning() bool {
|
func (m *RefreshManager) IsRunning() bool {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
return m.started
|
return m.started
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
|
// UpdateBaseDir changes the token directory at runtime.
|
||||||
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -143,16 +139,15 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
|
// SetOnTokenRefreshed registers a callback invoked after a successful token refresh.
|
||||||
// 可以在任何时候调用,支持运行时更新回调
|
// Can be called at any time; supports runtime callback updates.
|
||||||
// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据
|
|
||||||
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
|
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
m.onTokenRefreshed = callback
|
m.onTokenRefreshed = callback
|
||||||
|
|
||||||
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
|
// Update the refresher's callback in a thread-safe manner if already created
|
||||||
if m.refresher != nil {
|
if m.refresher != nil {
|
||||||
m.refresher.callbackMu.Lock()
|
m.refresher.callbackMu.Lock()
|
||||||
m.refresher.onTokenRefreshed = callback
|
m.refresher.onTokenRefreshed = callback
|
||||||
@@ -162,8 +157,11 @@ func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, token
|
|||||||
log.Debug("refresh manager: token refresh callback registered")
|
log.Debug("refresh manager: token refresh callback registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
|
// InitializeAndStart initializes and starts background refreshing (convenience method).
|
||||||
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||||
|
// Initialize global fingerprint config
|
||||||
|
initGlobalFingerprintConfig(cfg)
|
||||||
|
|
||||||
manager := GetRefreshManager()
|
manager := GetRefreshManager()
|
||||||
if err := manager.Initialize(baseDir, cfg); err != nil {
|
if err := manager.Initialize(baseDir, cfg); err != nil {
|
||||||
log.Errorf("refresh manager: initialization failed: %v", err)
|
log.Errorf("refresh manager: initialization failed: %v", err)
|
||||||
@@ -172,7 +170,31 @@ func InitializeAndStart(baseDir string, cfg *config.Config) {
|
|||||||
manager.Start()
|
manager.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopGlobalRefreshManager 停止全局刷新管理器
|
// initGlobalFingerprintConfig loads fingerprint settings from application config.
|
||||||
|
func initGlobalFingerprintConfig(cfg *config.Config) {
|
||||||
|
if cfg == nil || cfg.KiroFingerprint == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fpCfg := cfg.KiroFingerprint
|
||||||
|
SetGlobalFingerprintConfig(&FingerprintConfig{
|
||||||
|
OIDCSDKVersion: fpCfg.OIDCSDKVersion,
|
||||||
|
RuntimeSDKVersion: fpCfg.RuntimeSDKVersion,
|
||||||
|
StreamingSDKVersion: fpCfg.StreamingSDKVersion,
|
||||||
|
OSType: fpCfg.OSType,
|
||||||
|
OSVersion: fpCfg.OSVersion,
|
||||||
|
NodeVersion: fpCfg.NodeVersion,
|
||||||
|
KiroVersion: fpCfg.KiroVersion,
|
||||||
|
KiroHash: fpCfg.KiroHash,
|
||||||
|
})
|
||||||
|
log.Debug("kiro: global fingerprint config loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitFingerprintConfig initializes the global fingerprint config from application config.
|
||||||
|
func InitFingerprintConfig(cfg *config.Config) {
|
||||||
|
initGlobalFingerprintConfig(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopGlobalRefreshManager stops the global refresh manager.
|
||||||
func StopGlobalRefreshManager() {
|
func StopGlobalRefreshManager() {
|
||||||
if globalRefreshManager != nil {
|
if globalRefreshManager != nil {
|
||||||
globalRefreshManager.Stop()
|
globalRefreshManager.Stop()
|
||||||
|
|||||||
@@ -84,6 +84,8 @@ type SocialAuthClient struct {
|
|||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
protocolHandler *ProtocolHandler
|
protocolHandler *ProtocolHandler
|
||||||
|
machineID string
|
||||||
|
kiroVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSocialAuthClient creates a new social auth client.
|
// NewSocialAuthClient creates a new social auth client.
|
||||||
@@ -92,10 +94,13 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
}
|
}
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||||
return &SocialAuthClient{
|
return &SocialAuthClient{
|
||||||
httpClient: client,
|
httpClient: client,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
protocolHandler: NewProtocolHandler(),
|
protocolHandler: NewProtocolHandler(),
|
||||||
|
machineID: fp.KiroHash,
|
||||||
|
kiroVersion: fp.KiroVersion,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,7 +234,8 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
|
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
|
||||||
|
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(httpReq)
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -269,7 +275,8 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
|
||||||
|
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(httpReq)
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -466,7 +473,7 @@ func forceDefaultProtocolHandler() {
|
|||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
return // Non-Linux platforms use different handler mechanisms
|
return // Non-Linux platforms use different handler mechanisms
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set our handler as default using xdg-mime
|
// Set our handler as default using xdg-mime
|
||||||
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -40,21 +41,13 @@ const (
|
|||||||
// Authorization code flow callback
|
// Authorization code flow callback
|
||||||
authCodeCallbackPath = "/oauth/callback"
|
authCodeCallbackPath = "/oauth/callback"
|
||||||
authCodeCallbackPort = 19877
|
authCodeCallbackPort = 19877
|
||||||
|
|
||||||
// User-Agent to match official Kiro IDE
|
|
||||||
kiroUserAgent = "KiroIDE"
|
|
||||||
|
|
||||||
// IDC token refresh headers (matching Kiro IDE behavior)
|
|
||||||
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sentinel errors for OIDC token polling
|
|
||||||
var (
|
var (
|
||||||
ErrAuthorizationPending = errors.New("authorization_pending")
|
ErrAuthorizationPending = errors.New("authorization_pending")
|
||||||
ErrSlowDown = errors.New("slow_down")
|
ErrSlowDown = errors.New("slow_down")
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
|
||||||
type SSOOIDCClient struct {
|
type SSOOIDCClient struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -74,10 +67,10 @@ func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
|
|||||||
|
|
||||||
// RegisterClientResponse from AWS SSO OIDC.
|
// RegisterClientResponse from AWS SSO OIDC.
|
||||||
type RegisterClientResponse struct {
|
type RegisterClientResponse struct {
|
||||||
ClientID string `json:"clientId"`
|
ClientID string `json:"clientId"`
|
||||||
ClientSecret string `json:"clientSecret"`
|
ClientSecret string `json:"clientSecret"`
|
||||||
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
|
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
|
||||||
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
|
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartDeviceAuthResponse from AWS SSO OIDC.
|
// StartDeviceAuthResponse from AWS SSO OIDC.
|
||||||
@@ -174,8 +167,7 @@ func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region str
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -220,8 +212,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, cli
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -267,8 +258,7 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -311,8 +301,11 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
|
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific OIDC region.
|
||||||
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
endpoint := getOIDCEndpoint(region)
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
payload := map[string]string{
|
payload := map[string]string{
|
||||||
@@ -331,18 +324,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
SetOIDCHeaders(req)
|
||||||
// Set headers matching kiro2api's IDC token refresh
|
|
||||||
// These headers are required for successful IDC token refresh
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
|
||||||
req.Header.Set("Connection", "keep-alive")
|
|
||||||
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
|
||||||
req.Header.Set("Accept", "*/*")
|
|
||||||
req.Header.Set("Accept-Language", "*")
|
|
||||||
req.Header.Set("sec-fetch-mode", "cors")
|
|
||||||
req.Header.Set("User-Agent", "node")
|
|
||||||
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -469,10 +451,10 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
|
|||||||
|
|
||||||
// Step 5: Get profile ARN from CodeWhisperer API
|
// Step 5: Get profile ARN from CodeWhisperer API
|
||||||
fmt.Println("Fetching profile information...")
|
fmt.Println("Fetching profile information...")
|
||||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||||
|
|
||||||
// Fetch user email
|
// Fetch user email
|
||||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||||
if email != "" {
|
if email != "" {
|
||||||
fmt.Printf(" Logged in as: %s\n", email)
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
}
|
}
|
||||||
@@ -502,12 +484,36 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
|
|||||||
return nil, fmt.Errorf("authorization timed out")
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IDCLoginOptions holds optional parameters for IDC login.
|
||||||
|
type IDCLoginOptions struct {
|
||||||
|
StartURL string // Pre-configured start URL (skips prompt if set)
|
||||||
|
Region string // OIDC region for login and token refresh (defaults to us-east-1)
|
||||||
|
UseDeviceCode bool // Use Device Code flow instead of Auth Code flow
|
||||||
|
}
|
||||||
|
|
||||||
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
||||||
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
|
// Options can be provided to pre-configure IDC parameters (startURL, region).
|
||||||
|
// If StartURL is provided in opts, IDC flow is used directly without prompting.
|
||||||
|
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context, opts *IDCLoginOptions) (*KiroTokenData, error) {
|
||||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
fmt.Println("║ Kiro Authentication (AWS) ║")
|
fmt.Println("║ Kiro Authentication (AWS) ║")
|
||||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// If IDC options with StartURL are provided, skip method selection and use IDC directly
|
||||||
|
if opts != nil && opts.StartURL != "" {
|
||||||
|
region := opts.Region
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
fmt.Printf("\n Using IDC with Start URL: %s\n", opts.StartURL)
|
||||||
|
fmt.Printf(" Region: %s\n", region)
|
||||||
|
|
||||||
|
if opts.UseDeviceCode {
|
||||||
|
return c.LoginWithIDCAndOptions(ctx, opts.StartURL, region)
|
||||||
|
}
|
||||||
|
return c.LoginWithIDCAuthCode(ctx, opts.StartURL, region)
|
||||||
|
}
|
||||||
|
|
||||||
// Prompt for login method
|
// Prompt for login method
|
||||||
options := []string{
|
options := []string{
|
||||||
"Use with Builder ID (personal AWS account)",
|
"Use with Builder ID (personal AWS account)",
|
||||||
@@ -520,15 +526,41 @@ func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroToke
|
|||||||
return c.LoginWithBuilderID(ctx)
|
return c.LoginWithBuilderID(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IDC flow - prompt for start URL and region
|
// IDC flow - use pre-configured values or prompt
|
||||||
fmt.Println()
|
var startURL, region string
|
||||||
startURL := promptInput("? Enter Start URL", "")
|
|
||||||
if startURL == "" {
|
if opts != nil {
|
||||||
return nil, fmt.Errorf("start URL is required for IDC login")
|
startURL = opts.StartURL
|
||||||
|
region = opts.Region
|
||||||
}
|
}
|
||||||
|
|
||||||
region := promptInput("? Enter Region", defaultIDCRegion)
|
fmt.Println()
|
||||||
|
|
||||||
|
// Use pre-configured startURL or prompt
|
||||||
|
if startURL == "" {
|
||||||
|
startURL = promptInput("? Enter Start URL", "")
|
||||||
|
if startURL == "" {
|
||||||
|
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" Using pre-configured Start URL: %s\n", startURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pre-configured region or prompt
|
||||||
|
if region == "" {
|
||||||
|
region = promptInput("? Enter Region", defaultIDCRegion)
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" Using pre-configured Region: %s\n", region)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts != nil && opts.UseDeviceCode {
|
||||||
|
return c.LoginWithIDCAndOptions(ctx, startURL, region)
|
||||||
|
}
|
||||||
|
return c.LoginWithIDCAuthCode(ctx, startURL, region)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithIDCAndOptions performs IDC login with the specified region.
|
||||||
|
func (c *SSOOIDCClient) LoginWithIDCAndOptions(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||||
return c.LoginWithIDC(ctx, startURL, region)
|
return c.LoginWithIDC(ctx, startURL, region)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -550,8 +582,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -594,8 +625,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -639,8 +669,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -702,13 +731,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
SetOIDCHeaders(req)
|
||||||
// Set headers matching Kiro IDE behavior for better compatibility
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
|
|
||||||
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
|
||||||
req.Header.Set("User-Agent", "node")
|
|
||||||
req.Header.Set("Accept", "*/*")
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -835,12 +858,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
|||||||
log.Debugf("Failed to close browser: %v", err)
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 5: Get profile ARN from CodeWhisperer API
|
|
||||||
fmt.Println("Fetching profile information...")
|
|
||||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
|
||||||
|
|
||||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||||
if email != "" {
|
if email != "" {
|
||||||
fmt.Printf(" Logged in as: %s\n", email)
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
}
|
}
|
||||||
@@ -850,7 +869,7 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
|||||||
return &KiroTokenData{
|
return &KiroTokenData{
|
||||||
AccessToken: tokenResp.AccessToken,
|
AccessToken: tokenResp.AccessToken,
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
ProfileArn: profileArn,
|
ProfileArn: "", // Builder ID has no profile
|
||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
AuthMethod: "builder-id",
|
AuthMethod: "builder-id",
|
||||||
Provider: "AWS",
|
Provider: "AWS",
|
||||||
@@ -859,15 +878,15 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
|||||||
Email: email,
|
Email: email,
|
||||||
Region: defaultIDCRegion,
|
Region: defaultIDCRegion,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close browser on timeout for better UX
|
// Close browser on timeout for better UX
|
||||||
if err := browser.CloseBrowser(); err != nil {
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("authorization timed out")
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||||
// Falls back to JWT parsing if userinfo fails.
|
// Falls back to JWT parsing if userinfo fails.
|
||||||
@@ -931,20 +950,64 @@ func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken str
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
|
// FetchProfileArn fetches the profile ARN from ListAvailableProfiles API.
|
||||||
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
|
// This is used to get profileArn for imported accounts that may not have it.
|
||||||
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
|
func (c *SSOOIDCClient) FetchProfileArn(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||||
// Try ListProfiles API first
|
profileArn := c.tryListAvailableProfiles(ctx, accessToken, clientID, refreshToken)
|
||||||
profileArn := c.tryListProfiles(ctx, accessToken)
|
|
||||||
if profileArn != "" {
|
if profileArn != "" {
|
||||||
return profileArn
|
return profileArn
|
||||||
}
|
}
|
||||||
|
return c.tryListProfilesLegacy(ctx, accessToken)
|
||||||
// Fallback: Try ListAvailableCustomizations
|
|
||||||
return c.tryListCustomizations(ctx, accessToken)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
|
func (c *SSOOIDCClient) tryListAvailableProfiles(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetKiroAPIEndpoint("")+"/ListAvailableProfiles", strings.NewReader("{}"))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
accountKey := GetAccountKey(clientID, refreshToken)
|
||||||
|
setRuntimeHeaders(req, accessToken, accountKey)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("ListAvailableProfiles request failed: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("ListAvailableProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("ListAvailableProfiles response: %s", string(respBody))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Profiles []struct {
|
||||||
|
Arn string `json:"arn"`
|
||||||
|
ProfileName string `json:"profileName"`
|
||||||
|
} `json:"profiles"`
|
||||||
|
NextToken *string `json:"nextToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
log.Debugf("ListAvailableProfiles parse error: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Profiles) > 0 {
|
||||||
|
log.Debugf("Found profile: %s (%s)", result.Profiles[0].ProfileName, result.Profiles[0].Arn)
|
||||||
|
return result.Profiles[0].Arn
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) tryListProfilesLegacy(ctx context.Context, accessToken string) string {
|
||||||
payload := map[string]interface{}{
|
payload := map[string]interface{}{
|
||||||
"origin": "AI_EDITOR",
|
"origin": "AI_EDITOR",
|
||||||
}
|
}
|
||||||
@@ -954,7 +1017,9 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
// Use the legacy CodeWhisperer endpoint for JSON-RPC style requests.
|
||||||
|
// The Q endpoint (q.{region}.amazonaws.com) does NOT support x-amz-target headers.
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetCodeWhispererLegacyEndpoint(""), strings.NewReader(string(body)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -973,11 +1038,11 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
|
|||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
|
log.Debugf("ListProfiles (legacy) failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("ListProfiles response: %s", string(respBody))
|
log.Debugf("ListProfiles (legacy) response: %s", string(respBody))
|
||||||
|
|
||||||
var result struct {
|
var result struct {
|
||||||
Profiles []struct {
|
Profiles []struct {
|
||||||
@@ -1001,63 +1066,6 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
|
|
||||||
payload := map[string]interface{}{
|
|
||||||
"origin": "AI_EDITOR",
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
|
||||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
|
|
||||||
|
|
||||||
var result struct {
|
|
||||||
Customizations []struct {
|
|
||||||
Arn string `json:"arn"`
|
|
||||||
} `json:"customizations"`
|
|
||||||
ProfileArn string `json:"profileArn"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.ProfileArn != "" {
|
|
||||||
return result.ProfileArn
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(result.Customizations) > 0 {
|
|
||||||
return result.Customizations[0].Arn
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
|
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
|
||||||
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
|
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
|
||||||
payload := map[string]interface{}{
|
payload := map[string]interface{}{
|
||||||
@@ -1078,8 +1086,7 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1105,6 +1112,53 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) RegisterClientForAuthCodeWithIDC(ctx context.Context, redirectURI, issuerUrl, region string) (*RegisterClientResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"clientName": "Kiro IDE",
|
||||||
|
"clientType": "public",
|
||||||
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||||
|
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||||
|
"redirectUris": []string{redirectURI},
|
||||||
|
"issuerUrl": issuerUrl,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
SetOIDCHeaders(req)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("register client for auth code with IDC failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result RegisterClientResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// AuthCodeCallbackResult contains the result from authorization code callback.
|
// AuthCodeCallbackResult contains the result from authorization code callback.
|
||||||
type AuthCodeCallbackResult struct {
|
type AuthCodeCallbackResult struct {
|
||||||
Code string
|
Code string
|
||||||
@@ -1128,6 +1182,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
|||||||
port := listener.Addr().(*net.TCPAddr).Port
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
|
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
|
||||||
resultChan := make(chan AuthCodeCallbackResult, 1)
|
resultChan := make(chan AuthCodeCallbackResult, 1)
|
||||||
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
ReadHeaderTimeout: 10 * time.Second,
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
@@ -1147,6 +1202,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
|||||||
<html><head><title>Login Failed</title></head>
|
<html><head><title>Login Failed</title></head>
|
||||||
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||||
resultChan <- AuthCodeCallbackResult{Error: errParam}
|
resultChan <- AuthCodeCallbackResult{Error: errParam}
|
||||||
|
close(doneChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1156,6 +1212,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
|||||||
<html><head><title>Login Failed</title></head>
|
<html><head><title>Login Failed</title></head>
|
||||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||||
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
|
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
|
||||||
|
close(doneChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1164,6 +1221,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
|||||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||||
<script>window.close();</script></body></html>`)
|
<script>window.close();</script></body></html>`)
|
||||||
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
|
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
|
||||||
|
close(doneChan)
|
||||||
})
|
})
|
||||||
|
|
||||||
server.Handler = mux
|
server.Handler = mux
|
||||||
@@ -1178,7 +1236,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
case <-time.After(10 * time.Minute):
|
case <-time.After(10 * time.Minute):
|
||||||
case <-resultChan:
|
case <-doneChan:
|
||||||
}
|
}
|
||||||
_ = server.Shutdown(context.Background())
|
_ = server.Shutdown(context.Background())
|
||||||
}()
|
}()
|
||||||
@@ -1227,8 +1285,54 @@ func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, c
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
SetOIDCHeaders(req)
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) CreateTokenWithAuthCodeAndRegion(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI, region string) (*CreateTokenResponse, error) {
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"code": code,
|
||||||
|
"codeVerifier": codeVerifier,
|
||||||
|
"redirectUri": redirectURI,
|
||||||
|
"grantType": "authorization_code",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
SetOIDCHeaders(req)
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1352,12 +1456,118 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
|||||||
|
|
||||||
fmt.Println("\n✓ Authentication successful!")
|
fmt.Println("\n✓ Authentication successful!")
|
||||||
|
|
||||||
// Step 8: Get profile ARN
|
|
||||||
fmt.Println("Fetching profile information...")
|
|
||||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
|
||||||
|
|
||||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: "", // Builder ID has no profile
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "builder-id",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: regResp.ClientID,
|
||||||
|
ClientSecret: regResp.ClientSecret,
|
||||||
|
Email: email,
|
||||||
|
Region: defaultIDCRegion,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) LoginWithIDCAuthCode(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Authentication (AWS IDC - Auth Code) ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
if region == "" {
|
||||||
|
region = defaultIDCRegion
|
||||||
|
}
|
||||||
|
|
||||||
|
codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := generateStateForAuthCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\nStarting callback server...")
|
||||||
|
redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Callback server started, redirect URI: %s", redirectURI)
|
||||||
|
|
||||||
|
fmt.Println("Registering client...")
|
||||||
|
regResp, err := c.RegisterClientForAuthCodeWithIDC(ctx, redirectURI, startURL, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||||
|
|
||||||
|
endpoint := getOIDCEndpoint(region)
|
||||||
|
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
|
||||||
|
authURL := buildAuthorizationURL(endpoint, regResp.ClientID, redirectURI, scopes, state, codeChallenge)
|
||||||
|
|
||||||
|
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Println(" Opening browser for authentication...")
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf("\n URL: %s\n\n", authURL)
|
||||||
|
|
||||||
|
if c.cfg != nil {
|
||||||
|
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||||
|
} else {
|
||||||
|
browser.SetIncognitoMode(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := browser.OpenURL(authURL); err != nil {
|
||||||
|
log.Warnf("Could not open browser automatically: %v", err)
|
||||||
|
fmt.Println(" ⚠ Could not open browser automatically.")
|
||||||
|
fmt.Println(" Please open the URL above in your browser manually.")
|
||||||
|
} else {
|
||||||
|
fmt.Println(" (Browser opened automatically)")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n Waiting for authorization callback...")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(10 * time.Minute):
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
|
case result := <-resultChan:
|
||||||
|
if result.Error != "" {
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, fmt.Errorf("authorization failed: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n✓ Authorization received!")
|
||||||
|
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Exchanging code for tokens...")
|
||||||
|
tokenResp, err := c.CreateTokenWithAuthCodeAndRegion(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI, region)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n✓ Authentication successful!")
|
||||||
|
|
||||||
|
fmt.Println("Fetching profile information...")
|
||||||
|
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||||
|
|
||||||
|
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||||
if email != "" {
|
if email != "" {
|
||||||
fmt.Printf(" Logged in as: %s\n", email)
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
}
|
}
|
||||||
@@ -1369,12 +1579,25 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
|||||||
RefreshToken: tokenResp.RefreshToken,
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
ProfileArn: profileArn,
|
ProfileArn: profileArn,
|
||||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
AuthMethod: "builder-id",
|
AuthMethod: "idc",
|
||||||
Provider: "AWS",
|
Provider: "AWS",
|
||||||
ClientID: regResp.ClientID,
|
ClientID: regResp.ClientID,
|
||||||
ClientSecret: regResp.ClientSecret,
|
ClientSecret: regResp.ClientSecret,
|
||||||
Email: email,
|
Email: email,
|
||||||
Region: defaultIDCRegion,
|
StartURL: startURL,
|
||||||
|
Region: region,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildAuthorizationURL(endpoint, clientID, redirectURI, scopes, state, codeChallenge string) string {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("client_id", clientID)
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("scopes", scopes)
|
||||||
|
params.Set("state", state)
|
||||||
|
params.Set("code_challenge", codeChallenge)
|
||||||
|
params.Set("code_challenge_method", "S256")
|
||||||
|
return fmt.Sprintf("%s/authorize?%s", endpoint, params.Encode())
|
||||||
|
}
|
||||||
|
|||||||
261
internal/auth/kiro/sso_oidc_test.go
Normal file
261
internal/auth/kiro/sso_oidc_test.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type recordingRoundTripper struct {
|
||||||
|
lastReq *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
rt.lastReq = req
|
||||||
|
body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}`
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) {
|
||||||
|
rt := &recordingRoundTripper{}
|
||||||
|
client := &SSOOIDCClient{
|
||||||
|
httpClient: &http.Client{Transport: rt},
|
||||||
|
}
|
||||||
|
|
||||||
|
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456")
|
||||||
|
if profileArn == "" {
|
||||||
|
t.Fatal("expected profileArn, got empty result")
|
||||||
|
}
|
||||||
|
|
||||||
|
accountKey := GetAccountKey("client-id-123", "refresh-token-456")
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||||
|
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
|
||||||
|
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) {
|
||||||
|
rt := &recordingRoundTripper{}
|
||||||
|
client := &SSOOIDCClient{
|
||||||
|
httpClient: &http.Client{Transport: rt},
|
||||||
|
}
|
||||||
|
|
||||||
|
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789")
|
||||||
|
if profileArn == "" {
|
||||||
|
t.Fatal("expected profileArn, got empty result")
|
||||||
|
}
|
||||||
|
|
||||||
|
accountKey := GetAccountKey("", "refresh-token-789")
|
||||||
|
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||||
|
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
|
||||||
|
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterClientForAuthCodeWithIDC(t *testing.T) {
|
||||||
|
var capturedReq struct {
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Headers http.Header
|
||||||
|
Body map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
mockResp := RegisterClientResponse{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
ClientSecret: "test-client-secret",
|
||||||
|
ClientIDIssuedAt: 1700000000,
|
||||||
|
ClientSecretExpiresAt: 1700086400,
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
capturedReq.Method = r.Method
|
||||||
|
capturedReq.Path = r.URL.Path
|
||||||
|
capturedReq.Headers = r.Header.Clone()
|
||||||
|
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
json.Unmarshal(bodyBytes, &capturedReq.Body)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(mockResp)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Extract host to build a region that resolves to our test server.
|
||||||
|
// Override getOIDCEndpoint by passing region="" and patching the endpoint.
|
||||||
|
// Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we
|
||||||
|
// instead inject the test server URL directly via a custom HTTP client transport.
|
||||||
|
client := &SSOOIDCClient{
|
||||||
|
httpClient: ts.Client(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to route the request to our test server. Use a transport that rewrites the URL.
|
||||||
|
client.httpClient.Transport = &rewriteTransport{
|
||||||
|
base: ts.Client().Transport,
|
||||||
|
targetURL: ts.URL,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.RegisterClientForAuthCodeWithIDC(
|
||||||
|
context.Background(),
|
||||||
|
"http://127.0.0.1:19877/oauth/callback",
|
||||||
|
"https://my-idc-instance.awsapps.com/start",
|
||||||
|
"us-east-1",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify request method and path
|
||||||
|
if capturedReq.Method != http.MethodPost {
|
||||||
|
t.Errorf("method = %q, want POST", capturedReq.Method)
|
||||||
|
}
|
||||||
|
if capturedReq.Path != "/client/register" {
|
||||||
|
t.Errorf("path = %q, want /client/register", capturedReq.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify headers
|
||||||
|
if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" {
|
||||||
|
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||||
|
}
|
||||||
|
ua := capturedReq.Headers.Get("User-Agent")
|
||||||
|
if !strings.Contains(ua, "KiroIDE") {
|
||||||
|
t.Errorf("User-Agent %q does not contain KiroIDE", ua)
|
||||||
|
}
|
||||||
|
if !strings.Contains(ua, "sso-oidc") {
|
||||||
|
t.Errorf("User-Agent %q does not contain sso-oidc", ua)
|
||||||
|
}
|
||||||
|
xua := capturedReq.Headers.Get("X-Amz-User-Agent")
|
||||||
|
if !strings.Contains(xua, "KiroIDE") {
|
||||||
|
t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify body fields
|
||||||
|
if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" {
|
||||||
|
t.Errorf("clientName = %q, want %q", v, "Kiro IDE")
|
||||||
|
}
|
||||||
|
if v, _ := capturedReq.Body["clientType"].(string); v != "public" {
|
||||||
|
t.Errorf("clientType = %q, want %q", v, "public")
|
||||||
|
}
|
||||||
|
if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" {
|
||||||
|
t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify scopes array
|
||||||
|
scopesRaw, ok := capturedReq.Body["scopes"].([]interface{})
|
||||||
|
if !ok || len(scopesRaw) != 5 {
|
||||||
|
t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"])
|
||||||
|
}
|
||||||
|
expectedScopes := []string{
|
||||||
|
"codewhisperer:completions", "codewhisperer:analysis",
|
||||||
|
"codewhisperer:conversations", "codewhisperer:transformations",
|
||||||
|
"codewhisperer:taskassist",
|
||||||
|
}
|
||||||
|
for i, s := range expectedScopes {
|
||||||
|
if scopesRaw[i].(string) != s {
|
||||||
|
t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify grantTypes
|
||||||
|
grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{})
|
||||||
|
if !ok || len(grantTypesRaw) != 2 {
|
||||||
|
t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"])
|
||||||
|
}
|
||||||
|
if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" {
|
||||||
|
t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify redirectUris
|
||||||
|
redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{})
|
||||||
|
if !ok || len(redirectRaw) != 1 {
|
||||||
|
t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"])
|
||||||
|
}
|
||||||
|
if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" {
|
||||||
|
t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response parsing
|
||||||
|
if resp.ClientID != "test-client-id" {
|
||||||
|
t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id")
|
||||||
|
}
|
||||||
|
if resp.ClientSecret != "test-client-secret" {
|
||||||
|
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret")
|
||||||
|
}
|
||||||
|
if resp.ClientIDIssuedAt != 1700000000 {
|
||||||
|
t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000)
|
||||||
|
}
|
||||||
|
if resp.ClientSecretExpiresAt != 1700086400 {
|
||||||
|
t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteTransport redirects all requests to the test server URL.
|
||||||
|
type rewriteTransport struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
targetURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
target, _ := url.Parse(t.targetURL)
|
||||||
|
req.URL.Scheme = target.Scheme
|
||||||
|
req.URL.Host = target.Host
|
||||||
|
if t.base != nil {
|
||||||
|
return t.base.RoundTrip(req)
|
||||||
|
}
|
||||||
|
return http.DefaultTransport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthorizationURL(t *testing.T) {
|
||||||
|
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
|
||||||
|
endpoint := "https://oidc.us-east-1.amazonaws.com"
|
||||||
|
redirectURI := "http://127.0.0.1:19877/oauth/callback"
|
||||||
|
|
||||||
|
authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge")
|
||||||
|
|
||||||
|
// Verify colons and commas in scopes are percent-encoded
|
||||||
|
if !strings.Contains(authURL, "codewhisperer%3Acompletions") {
|
||||||
|
t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(authURL, "completions%2Ccodewhisperer") {
|
||||||
|
t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse back and verify all parameters round-trip correctly
|
||||||
|
parsed, err := url.Parse(authURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse auth URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(authURL, endpoint+"/authorize?") {
|
||||||
|
t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
q := parsed.Query()
|
||||||
|
checks := map[string]string{
|
||||||
|
"response_type": "code",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"redirect_uri": redirectURI,
|
||||||
|
"scopes": scopes,
|
||||||
|
"state": "random-state",
|
||||||
|
"code_challenge": "test-challenge",
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
}
|
||||||
|
for key, want := range checks {
|
||||||
|
if got := q.Get(key); got != want {
|
||||||
|
t.Errorf("%s = %q, want %q", key, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -29,7 +29,7 @@ type KiroTokenStorage struct {
|
|||||||
ClientID string `json:"client_id,omitempty"`
|
ClientID string `json:"client_id,omitempty"`
|
||||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||||
ClientSecret string `json:"client_secret,omitempty"`
|
ClientSecret string `json:"client_secret,omitempty"`
|
||||||
// Region is the AWS region
|
// Region is the OIDC region for IDC login and token refresh
|
||||||
Region string `json:"region,omitempty"`
|
Region string `json:"region,omitempty"`
|
||||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||||
StartURL string `json:"start_url,omitempty"`
|
StartURL string `json:"start_url,omitempty"`
|
||||||
|
|||||||
@@ -200,36 +200,22 @@ func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 解析各字段
|
// 解析各字段
|
||||||
if v, ok := metadata["access_token"].(string); ok {
|
token.AccessToken, _ = metadata["access_token"].(string)
|
||||||
token.AccessToken = v
|
token.RefreshToken, _ = metadata["refresh_token"].(string)
|
||||||
}
|
token.ClientID, _ = metadata["client_id"].(string)
|
||||||
if v, ok := metadata["refresh_token"].(string); ok {
|
token.ClientSecret, _ = metadata["client_secret"].(string)
|
||||||
token.RefreshToken = v
|
token.Region, _ = metadata["region"].(string)
|
||||||
}
|
token.StartURL, _ = metadata["start_url"].(string)
|
||||||
if v, ok := metadata["client_id"].(string); ok {
|
token.Provider, _ = metadata["provider"].(string)
|
||||||
token.ClientID = v
|
|
||||||
}
|
|
||||||
if v, ok := metadata["client_secret"].(string); ok {
|
|
||||||
token.ClientSecret = v
|
|
||||||
}
|
|
||||||
if v, ok := metadata["region"].(string); ok {
|
|
||||||
token.Region = v
|
|
||||||
}
|
|
||||||
if v, ok := metadata["start_url"].(string); ok {
|
|
||||||
token.StartURL = v
|
|
||||||
}
|
|
||||||
if v, ok := metadata["provider"].(string); ok {
|
|
||||||
token.Provider = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析时间字段
|
// 解析时间字段
|
||||||
if v, ok := metadata["expires_at"].(string); ok {
|
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||||
token.ExpiresAt = t
|
token.ExpiresAt = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v, ok := metadata["last_refresh"].(string); ok {
|
if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" {
|
||||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
|
||||||
token.LastVerified = t
|
token.LastVerified = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -51,14 +50,12 @@ type QuotaStatus struct {
|
|||||||
// UsageChecker provides methods for checking token quota usage.
|
// UsageChecker provides methods for checking token quota usage.
|
||||||
type UsageChecker struct {
|
type UsageChecker struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
endpoint string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUsageChecker creates a new UsageChecker instance.
|
// NewUsageChecker creates a new UsageChecker instance.
|
||||||
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||||
return &UsageChecker{
|
return &UsageChecker{
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||||
endpoint: awsKiroEndpoint,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +63,6 @@ func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
|||||||
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||||
return &UsageChecker{
|
return &UsageChecker{
|
||||||
httpClient: client,
|
httpClient: client,
|
||||||
endpoint: awsKiroEndpoint,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,26 +76,23 @@ func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData)
|
|||||||
return nil, fmt.Errorf("access token is empty")
|
return nil, fmt.Errorf("access token is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
payload := map[string]interface{}{
|
queryParams := map[string]string{
|
||||||
"origin": "AI_EDITOR",
|
"origin": "AI_EDITOR",
|
||||||
"profileArn": tokenData.ProfileArn,
|
"profileArn": tokenData.ProfileArn,
|
||||||
"resourceType": "AGENTIC_REQUEST",
|
"resourceType": "AGENTIC_REQUEST",
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonBody, err := json.Marshal(payload)
|
// Use endpoint from profileArn if available
|
||||||
if err != nil {
|
endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn)
|
||||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||||
req.Header.Set("x-amz-target", targetGetUsage)
|
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||||
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -206,3 +206,52 @@ func DoKiroImport(cfg *config.Config, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
fmt.Println("Kiro token import successful!")
|
fmt.Println("Kiro token import successful!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startURL == "" {
|
||||||
|
log.Errorf("Kiro IDC login requires --kiro-idc-start-url")
|
||||||
|
fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
metadata := map[string]string{
|
||||||
|
"start-url": startURL,
|
||||||
|
"region": region,
|
||||||
|
"flow": flow,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: metadata,
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro IDC authentication failed: %v", err)
|
||||||
|
fmt.Println("\nTroubleshooting:")
|
||||||
|
fmt.Println("1. Make sure your IDC Start URL is correct")
|
||||||
|
fmt.Println("2. Complete the authorization in the browser")
|
||||||
|
fmt.Println("3. If auth code flow fails, try: --kiro-idc-flow device")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro IDC authentication successful!")
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -27,11 +28,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
geminiCLIVersion = "v1internal"
|
geminiCLIVersion = "v1internal"
|
||||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
|
||||||
geminiCLIApiClient = "gl-node/22.17.0"
|
|
||||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type projectSelectionRequiredError struct{}
|
type projectSelectionRequiredError struct{}
|
||||||
@@ -409,9 +407,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
|||||||
return fmt.Errorf("create request: %w", errRequest)
|
return fmt.Errorf("create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
|
||||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
@@ -630,7 +626,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||||
@@ -651,7 +647,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
resp, errDo = httpClient.Do(req)
|
resp, errDo = httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||||
|
|||||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
codexLoginModeMetadataKey = "codex_login_mode"
|
||||||
|
codexLoginModeDevice = "device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||||
|
// existing codex-login OAuth callback flow intact.
|
||||||
|
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
if authErr.Type == codex.ErrPortInUse.Type {
|
||||||
|
os.Exit(codex.ErrPortInUse.Code)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("Codex device authentication successful!")
|
||||||
|
}
|
||||||
@@ -69,6 +69,9 @@ type Config struct {
|
|||||||
|
|
||||||
// RequestRetry defines the retry times when the request failed.
|
// RequestRetry defines the retry times when the request failed.
|
||||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||||
|
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||||
|
// Set to 0 or a negative value to keep trying all available credentials (legacy behavior).
|
||||||
|
MaxRetryCredentials int `yaml:"max-retry-credentials" json:"max-retry-credentials"`
|
||||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||||
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
||||||
|
|
||||||
@@ -87,6 +90,10 @@ type Config struct {
|
|||||||
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
|
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
|
||||||
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
|
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
|
||||||
|
|
||||||
|
// KiroFingerprint defines a global fingerprint configuration for all Kiro requests.
|
||||||
|
// When set, all Kiro requests will use this fixed fingerprint instead of random generation.
|
||||||
|
KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"`
|
||||||
|
|
||||||
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
|
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
|
||||||
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
|
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
|
||||||
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
|
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
|
||||||
@@ -477,6 +484,9 @@ type KiroKey struct {
|
|||||||
// Region is the AWS region (default: us-east-1).
|
// Region is the AWS region (default: us-east-1).
|
||||||
Region string `yaml:"region,omitempty" json:"region,omitempty"`
|
Region string `yaml:"region,omitempty" json:"region,omitempty"`
|
||||||
|
|
||||||
|
// StartURL is the IAM Identity Center (IDC) start URL for SSO login.
|
||||||
|
StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"`
|
||||||
|
|
||||||
// ProxyURL optionally overrides the global proxy for this configuration.
|
// ProxyURL optionally overrides the global proxy for this configuration.
|
||||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||||
|
|
||||||
@@ -489,6 +499,20 @@ type KiroKey struct {
|
|||||||
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
|
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests.
|
||||||
|
// When configured, all Kiro requests will use this fixed fingerprint instead of random generation.
|
||||||
|
// Empty fields will fall back to random selection from built-in pools.
|
||||||
|
type KiroFingerprintConfig struct {
|
||||||
|
OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"`
|
||||||
|
RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"`
|
||||||
|
StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"`
|
||||||
|
OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"`
|
||||||
|
OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"`
|
||||||
|
NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"`
|
||||||
|
KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"`
|
||||||
|
KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||||
type OpenAICompatibility struct {
|
type OpenAICompatibility struct {
|
||||||
@@ -555,16 +579,6 @@ func LoadConfig(configFile string) (*Config, error) {
|
|||||||
// If optional is true and the file is missing, it returns an empty Config.
|
// If optional is true and the file is missing, it returns an empty Config.
|
||||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
|
||||||
// Reason: avoid mutating config.yaml during server startup.
|
|
||||||
// Re-enable the block below if automatic startup migration is needed again.
|
|
||||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
|
||||||
// // Log warning but don't fail - config loading should still work
|
|
||||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
|
||||||
// } else if migrated {
|
|
||||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Read the entire configuration file into memory.
|
// Read the entire configuration file into memory.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -652,6 +666,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.ErrorLogsMaxFiles = 10
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.MaxRetryCredentials < 0 {
|
||||||
|
cfg.MaxRetryCredentials = 0
|
||||||
|
}
|
||||||
|
|
||||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||||
cfg.SanitizeGeminiKeys()
|
cfg.SanitizeGeminiKeys()
|
||||||
|
|
||||||
@@ -1648,9 +1666,6 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
|
|||||||
srcIdx := findMapKeyIndex(srcRoot, key)
|
srcIdx := findMapKeyIndex(srcRoot, key)
|
||||||
if srcIdx < 0 {
|
if srcIdx < 0 {
|
||||||
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
||||||
//
|
|
||||||
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
|
|
||||||
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
|
|
||||||
// When users delete the last channel from oauth-model-alias via the management API,
|
// When users delete the last channel from oauth-model-alias via the management API,
|
||||||
// we want that deletion to persist across hot reloads and restarts.
|
// we want that deletion to persist across hot reloads and restarts.
|
||||||
if key == "oauth-model-alias" {
|
if key == "oauth-model-alias" {
|
||||||
|
|||||||
61
internal/config/oauth_model_alias_defaults.go
Normal file
61
internal/config/oauth_model_alias_defaults.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// defaultKiroAliases returns default oauth-model-alias entries for Kiro.
|
||||||
|
// These aliases expose standard Claude IDs for Kiro-prefixed upstream models.
|
||||||
|
func defaultKiroAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
// Sonnet 4.6
|
||||||
|
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
|
// Sonnet 4.5
|
||||||
|
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||||
|
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||||
|
// Sonnet 4
|
||||||
|
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||||
|
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||||
|
// Opus 4.6
|
||||||
|
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||||
|
// Opus 4.5
|
||||||
|
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||||
|
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
// Haiku 4.5
|
||||||
|
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||||
|
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultGitHubCopilotAliases returns default oauth-model-alias entries for
|
||||||
|
// GitHub Copilot Claude models. It exposes hyphen-style IDs used by clients.
|
||||||
|
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||||
|
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubCopilotAliasesFromModels generates oauth-model-alias entries from a dynamic
|
||||||
|
// list of model IDs fetched from the Copilot API. It auto-creates aliases for
|
||||||
|
// models whose ID contains a dot (e.g. "claude-opus-4.6" → "claude-opus-4-6"),
|
||||||
|
// which is the pattern used by Claude models on Copilot.
|
||||||
|
func GitHubCopilotAliasesFromModels(modelIDs []string) []OAuthModelAlias {
|
||||||
|
var aliases []OAuthModelAlias
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
for _, id := range modelIDs {
|
||||||
|
if !strings.Contains(id, ".") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hyphenID := strings.ReplaceAll(id, ".", "-")
|
||||||
|
key := id + "→" + hyphenID
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
aliases = append(aliases, OAuthModelAlias{Name: id, Alias: hyphenID, Fork: true})
|
||||||
|
}
|
||||||
|
return aliases
|
||||||
|
}
|
||||||
@@ -1,314 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// antigravityModelConversionTable maps old built-in aliases to actual model names
|
|
||||||
// for the antigravity channel during migration.
|
|
||||||
var antigravityModelConversionTable = map[string]string{
|
|
||||||
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
|
|
||||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
|
||||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
|
||||||
"gemini-3-flash-preview": "gemini-3-flash",
|
|
||||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
|
||||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
|
||||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
|
||||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
|
||||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
|
||||||
// names so that clients like Claude Code can use standard names directly.
|
|
||||||
func defaultKiroAliases() []OAuthModelAlias {
|
|
||||||
return []OAuthModelAlias{
|
|
||||||
// Sonnet 4.5
|
|
||||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
|
||||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
|
||||||
// Sonnet 4
|
|
||||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
|
||||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
|
||||||
// Opus 4.6
|
|
||||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
|
||||||
// Opus 4.5
|
|
||||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
|
||||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
|
||||||
// Haiku 4.5
|
|
||||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
|
||||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
|
|
||||||
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
|
|
||||||
// This keeps compatibility with clients (e.g. Claude Code) that use
|
|
||||||
// Anthropic-style model IDs like "claude-opus-4-6".
|
|
||||||
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
|
||||||
return []OAuthModelAlias{
|
|
||||||
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
|
||||||
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
|
||||||
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
|
||||||
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
|
||||||
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
|
||||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
|
||||||
// for the antigravity channel when neither field exists.
|
|
||||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
|
||||||
return []OAuthModelAlias{
|
|
||||||
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
|
|
||||||
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
|
|
||||||
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
|
|
||||||
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
|
|
||||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
|
||||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
|
||||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
|
||||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
|
|
||||||
// to oauth-model-alias at startup. Returns true if migration was performed.
|
|
||||||
//
|
|
||||||
// Migration flow:
|
|
||||||
// 1. Check if oauth-model-alias exists -> skip migration
|
|
||||||
// 2. Check if oauth-model-mappings exists -> convert and migrate
|
|
||||||
// - For antigravity channel, convert old built-in aliases to actual model names
|
|
||||||
//
|
|
||||||
// 3. Neither exists -> add default antigravity config
|
|
||||||
func MigrateOAuthModelAlias(configFile string) (bool, error) {
|
|
||||||
data, err := os.ReadFile(configFile)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if len(data) == 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse YAML into node tree to preserve structure
|
|
||||||
var root yaml.Node
|
|
||||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
rootMap := root.Content[0]
|
|
||||||
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if oauth-model-alias already exists
|
|
||||||
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if oauth-model-mappings exists
|
|
||||||
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
|
|
||||||
if oldIdx >= 0 {
|
|
||||||
// Migrate from old field
|
|
||||||
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Neither field exists - add default antigravity config
|
|
||||||
return addDefaultAntigravityConfig(configFile, &root, rootMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
|
|
||||||
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
|
|
||||||
if oldIdx+1 >= len(rootMap.Content) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
oldValue := rootMap.Content[oldIdx+1]
|
|
||||||
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the old aliases
|
|
||||||
oldAliases := parseOldAliasNode(oldValue)
|
|
||||||
if len(oldAliases) == 0 {
|
|
||||||
// Remove the old field and write
|
|
||||||
removeMapKeyByIndex(rootMap, oldIdx)
|
|
||||||
return writeYAMLNode(configFile, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert model names for antigravity channel
|
|
||||||
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
|
|
||||||
for channel, entries := range oldAliases {
|
|
||||||
converted := make([]OAuthModelAlias, 0, len(entries))
|
|
||||||
for _, entry := range entries {
|
|
||||||
newEntry := OAuthModelAlias{
|
|
||||||
Name: entry.Name,
|
|
||||||
Alias: entry.Alias,
|
|
||||||
Fork: entry.Fork,
|
|
||||||
}
|
|
||||||
// Convert model names for antigravity channel
|
|
||||||
if strings.EqualFold(channel, "antigravity") {
|
|
||||||
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
|
|
||||||
newEntry.Name = actual
|
|
||||||
}
|
|
||||||
}
|
|
||||||
converted = append(converted, newEntry)
|
|
||||||
}
|
|
||||||
newAliases[channel] = converted
|
|
||||||
}
|
|
||||||
|
|
||||||
// For antigravity channel, supplement missing default aliases
|
|
||||||
if antigravityEntries, exists := newAliases["antigravity"]; exists {
|
|
||||||
// Build a set of already configured model names (upstream names)
|
|
||||||
configuredModels := make(map[string]bool, len(antigravityEntries))
|
|
||||||
for _, entry := range antigravityEntries {
|
|
||||||
configuredModels[entry.Name] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add missing default aliases
|
|
||||||
for _, defaultAlias := range defaultAntigravityAliases() {
|
|
||||||
if !configuredModels[defaultAlias.Name] {
|
|
||||||
antigravityEntries = append(antigravityEntries, defaultAlias)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newAliases["antigravity"] = antigravityEntries
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build new node
|
|
||||||
newNode := buildOAuthModelAliasNode(newAliases)
|
|
||||||
|
|
||||||
// Replace old key with new key and value
|
|
||||||
rootMap.Content[oldIdx].Value = "oauth-model-alias"
|
|
||||||
rootMap.Content[oldIdx+1] = newNode
|
|
||||||
|
|
||||||
return writeYAMLNode(configFile, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addDefaultAntigravityConfig adds the default antigravity configuration
|
|
||||||
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
|
|
||||||
defaults := map[string][]OAuthModelAlias{
|
|
||||||
"antigravity": defaultAntigravityAliases(),
|
|
||||||
}
|
|
||||||
newNode := buildOAuthModelAliasNode(defaults)
|
|
||||||
|
|
||||||
// Add new key-value pair
|
|
||||||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
|
|
||||||
rootMap.Content = append(rootMap.Content, keyNode, newNode)
|
|
||||||
|
|
||||||
return writeYAMLNode(configFile, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseOldAliasNode parses the old oauth-model-mappings node structure
|
|
||||||
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
|
|
||||||
if node == nil || node.Kind != yaml.MappingNode {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
result := make(map[string][]OAuthModelAlias)
|
|
||||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
|
||||||
channelNode := node.Content[i]
|
|
||||||
entriesNode := node.Content[i+1]
|
|
||||||
if channelNode == nil || entriesNode == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
|
|
||||||
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
|
|
||||||
for _, entryNode := range entriesNode.Content {
|
|
||||||
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := parseAliasEntry(entryNode)
|
|
||||||
if entry.Name != "" && entry.Alias != "" {
|
|
||||||
entries = append(entries, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(entries) > 0 {
|
|
||||||
result[channel] = entries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseAliasEntry parses a single alias entry node
|
|
||||||
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
|
|
||||||
var entry OAuthModelAlias
|
|
||||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
|
||||||
keyNode := node.Content[i]
|
|
||||||
valNode := node.Content[i+1]
|
|
||||||
if keyNode == nil || valNode == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
|
|
||||||
case "name":
|
|
||||||
entry.Name = strings.TrimSpace(valNode.Value)
|
|
||||||
case "alias":
|
|
||||||
entry.Alias = strings.TrimSpace(valNode.Value)
|
|
||||||
case "fork":
|
|
||||||
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
|
|
||||||
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
|
|
||||||
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
|
||||||
for channel, entries := range aliases {
|
|
||||||
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
|
|
||||||
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
|
|
||||||
for _, entry := range entries {
|
|
||||||
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
|
||||||
entryNode.Content = append(entryNode.Content,
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
|
|
||||||
)
|
|
||||||
if entry.Fork {
|
|
||||||
entryNode.Content = append(entryNode.Content,
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
entriesNode.Content = append(entriesNode.Content, entryNode)
|
|
||||||
}
|
|
||||||
node.Content = append(node.Content, channelNode, entriesNode)
|
|
||||||
}
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
|
|
||||||
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
|
|
||||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeYAMLNode writes the YAML node tree back to file
|
|
||||||
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
|
|
||||||
f, err := os.Create(configFile)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
enc := yaml.NewEncoder(f)
|
|
||||||
enc.SetIndent(2)
|
|
||||||
if err := enc.Encode(root); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if err := enc.Close(); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
@@ -1,245 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `oauth-model-alias:
|
|
||||||
gemini-cli:
|
|
||||||
- name: "gemini-2.5-pro"
|
|
||||||
alias: "g2.5p"
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if migrated {
|
|
||||||
t.Fatal("expected no migration when oauth-model-alias already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify file unchanged
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
|
||||||
t.Fatal("file should still contain oauth-model-alias")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `oauth-model-mappings:
|
|
||||||
gemini-cli:
|
|
||||||
- name: "gemini-2.5-pro"
|
|
||||||
alias: "g2.5p"
|
|
||||||
fork: true
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify new field exists and old field removed
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
if strings.Contains(string(data), "oauth-model-mappings:") {
|
|
||||||
t.Fatal("old field should be removed")
|
|
||||||
}
|
|
||||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
|
||||||
t.Fatal("new field should exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse and verify structure
|
|
||||||
var root yaml.Node
|
|
||||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
// Use old model names that should be converted
|
|
||||||
content := `oauth-model-mappings:
|
|
||||||
antigravity:
|
|
||||||
- name: "gemini-2.5-computer-use-preview-10-2025"
|
|
||||||
alias: "computer-use"
|
|
||||||
- name: "gemini-3-pro-preview"
|
|
||||||
alias: "g3p"
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify model names were converted
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
content = string(data)
|
|
||||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
|
||||||
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "gemini-3-pro-high") {
|
|
||||||
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify missing default aliases were supplemented
|
|
||||||
if !strings.Contains(content, "gemini-3-pro-image") {
|
|
||||||
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "gemini-3-flash") {
|
|
||||||
t.Fatal("expected missing default alias gemini-3-flash to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-sonnet-4-5") {
|
|
||||||
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
|
|
||||||
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
|
||||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
|
||||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `debug: true
|
|
||||||
port: 8080
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to add default config")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify default antigravity config was added
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
content = string(data)
|
|
||||||
if !strings.Contains(content, "oauth-model-alias:") {
|
|
||||||
t.Fatal("expected oauth-model-alias to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "antigravity:") {
|
|
||||||
t.Fatal("expected antigravity channel to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
|
||||||
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `debug: true
|
|
||||||
port: 8080
|
|
||||||
oauth-model-mappings:
|
|
||||||
gemini-cli:
|
|
||||||
- name: "test"
|
|
||||||
alias: "t"
|
|
||||||
api-keys:
|
|
||||||
- "key1"
|
|
||||||
- "key2"
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify other config preserved
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
content = string(data)
|
|
||||||
if !strings.Contains(content, "debug: true") {
|
|
||||||
t.Fatal("expected debug field to be preserved")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "port: 8080") {
|
|
||||||
t.Fatal("expected port field to be preserved")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "api-keys:") {
|
|
||||||
t.Fatal("expected api-keys field to be preserved")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for nonexistent file: %v", err)
|
|
||||||
}
|
|
||||||
if migrated {
|
|
||||||
t.Fatal("expected no migration for nonexistent file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if migrated {
|
|
||||||
t.Fatal("expected no migration for empty file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -34,6 +34,9 @@ type VertexCompatKey struct {
|
|||||||
|
|
||||||
// Models defines the model configurations including aliases for routing.
|
// Models defines the model configurations including aliases for routing.
|
||||||
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||||
|
|
||||||
|
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||||
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
||||||
@@ -74,6 +77,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
|||||||
}
|
}
|
||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
|
|
||||||
// Sanitize models: remove entries without valid alias
|
// Sanitize models: remove entries without valid alias
|
||||||
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude.","cache_control":{"type":"ephemeral"}}]
|
[{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}]
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package misc
|
package misc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
|||||||
func LogCredentialSeparator() {
|
func LogCredentialSeparator() {
|
||||||
log.Debug(credentialSeparator)
|
log.Debug(credentialSeparator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||||
|
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||||
|
var data map[string]any
|
||||||
|
|
||||||
|
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||||
|
if srcMap, ok := source.(map[string]any); ok {
|
||||||
|
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||||
|
for k, v := range srcMap {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||||
|
temp, err := json.Marshal(source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(temp, &data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge extra metadata
|
||||||
|
if metadata != nil {
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]any)
|
||||||
|
}
|
||||||
|
for k, v := range metadata {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,10 +4,98 @@
|
|||||||
package misc
|
package misc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
|
||||||
|
GeminiCLIVersion = "0.31.0"
|
||||||
|
|
||||||
|
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
|
||||||
|
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI.
|
||||||
|
func geminiCLIOS() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
return "win32"
|
||||||
|
default:
|
||||||
|
return runtime.GOOS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI.
|
||||||
|
func geminiCLIArch() string {
|
||||||
|
switch runtime.GOARCH {
|
||||||
|
case "amd64":
|
||||||
|
return "x64"
|
||||||
|
case "386":
|
||||||
|
return "x86"
|
||||||
|
default:
|
||||||
|
return runtime.GOARCH
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format.
|
||||||
|
// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable.
|
||||||
|
func GeminiCLIUserAgent(model string) string {
|
||||||
|
if model == "" {
|
||||||
|
model = "unknown"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
|
||||||
|
// proxy infrastructure, client identity, or browser fingerprints from an
|
||||||
|
// outgoing request. This ensures requests to upstream services look like they
|
||||||
|
// originate directly from a native client rather than a third-party client
|
||||||
|
// behind a reverse proxy.
|
||||||
|
func ScrubProxyAndFingerprintHeaders(req *http.Request) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Proxy tracing headers ---
|
||||||
|
req.Header.Del("X-Forwarded-For")
|
||||||
|
req.Header.Del("X-Forwarded-Host")
|
||||||
|
req.Header.Del("X-Forwarded-Proto")
|
||||||
|
req.Header.Del("X-Forwarded-Port")
|
||||||
|
req.Header.Del("X-Real-IP")
|
||||||
|
req.Header.Del("Forwarded")
|
||||||
|
req.Header.Del("Via")
|
||||||
|
|
||||||
|
// --- Client identity headers ---
|
||||||
|
req.Header.Del("X-Title")
|
||||||
|
req.Header.Del("X-Stainless-Lang")
|
||||||
|
req.Header.Del("X-Stainless-Package-Version")
|
||||||
|
req.Header.Del("X-Stainless-Os")
|
||||||
|
req.Header.Del("X-Stainless-Arch")
|
||||||
|
req.Header.Del("X-Stainless-Runtime")
|
||||||
|
req.Header.Del("X-Stainless-Runtime-Version")
|
||||||
|
req.Header.Del("Http-Referer")
|
||||||
|
req.Header.Del("Referer")
|
||||||
|
|
||||||
|
// --- Browser / Chromium fingerprint headers ---
|
||||||
|
// These are sent by Electron-based clients (e.g. CherryStudio) using the
|
||||||
|
// Fetch API, but NOT by Node.js https module (which Antigravity uses).
|
||||||
|
req.Header.Del("Sec-Ch-Ua")
|
||||||
|
req.Header.Del("Sec-Ch-Ua-Mobile")
|
||||||
|
req.Header.Del("Sec-Ch-Ua-Platform")
|
||||||
|
req.Header.Del("Sec-Fetch-Mode")
|
||||||
|
req.Header.Del("Sec-Fetch-Site")
|
||||||
|
req.Header.Del("Sec-Fetch-Dest")
|
||||||
|
req.Header.Del("Priority")
|
||||||
|
|
||||||
|
// --- Encoding negotiation ---
|
||||||
|
// Antigravity (Node.js) sends "gzip, deflate, br" by default;
|
||||||
|
// Electron-based clients may add "zstd" which is a fingerprint mismatch.
|
||||||
|
req.Header.Del("Accept-Encoding")
|
||||||
|
}
|
||||||
|
|
||||||
// EnsureHeader ensures that a header exists in the target header map by checking
|
// EnsureHeader ensures that a header exists in the target header map by checking
|
||||||
// multiple sources in order of priority: source headers, existing target headers,
|
// multiple sources in order of priority: source headers, existing target headers,
|
||||||
// and finally the default value. It only sets the header if it's not already present
|
// and finally the default value. It only sets the header if it's not already present
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import (
|
|||||||
// - kiro
|
// - kiro
|
||||||
// - kilo
|
// - kilo
|
||||||
// - github-copilot
|
// - github-copilot
|
||||||
// - kiro
|
|
||||||
// - amazonq
|
// - amazonq
|
||||||
// - antigravity (returns static overrides only)
|
// - antigravity (returns static overrides only)
|
||||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||||
@@ -152,6 +151,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +166,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: entry.Description,
|
Description: entry.Description,
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -696,6 +697,42 @@ func GetKiroModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-deepseek-3-2-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro DeepSeek 3.2 (Agentic)",
|
||||||
|
Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-minimax-m2-1-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro MiniMax M2.1 (Agentic)",
|
||||||
|
Description: "MiniMax M2.1 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-qwen3-coder-next-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Qwen3 Coder Next (Agentic)",
|
||||||
|
Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func GetClaudeModels() []*ModelInfo {
|
|||||||
DisplayName: "Claude 4.6 Sonnet",
|
DisplayName: "Claude 4.6 Sonnet",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4-6",
|
ID: "claude-opus-4-6",
|
||||||
@@ -49,7 +49,7 @@ func GetClaudeModels() []*ModelInfo {
|
|||||||
Description: "Premium model combining maximum intelligence with practical performance",
|
Description: "Premium model combining maximum intelligence with practical performance",
|
||||||
ContextLength: 1000000,
|
ContextLength: 1000000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4-6",
|
ID: "claude-sonnet-4-6",
|
||||||
@@ -839,6 +839,20 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1772668800,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.4",
|
||||||
|
DisplayName: "GPT 5.4",
|
||||||
|
Description: "Stable version of GPT 5.4 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
|
ContextLength: 1_050_000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -916,19 +930,12 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
Created int64
|
Created int64
|
||||||
Thinking *ThinkingSupport
|
Thinking *ThinkingSupport
|
||||||
}{
|
}{
|
||||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
|
||||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
|
||||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
|
||||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||||
@@ -937,11 +944,7 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
|
||||||
}
|
}
|
||||||
models := make([]*ModelInfo, 0, len(entries))
|
models := make([]*ModelInfo, 0, len(entries))
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
@@ -970,21 +973,17 @@ type AntigravityModelConfig struct {
|
|||||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
return map[string]*AntigravityModelConfig{
|
return map[string]*AntigravityModelConfig{
|
||||||
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"gpt-oss-120b-medium": {},
|
||||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
|
||||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
|
||||||
"gpt-oss-120b-medium": {},
|
|
||||||
"tab_flash_lite_preview": {},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ type ModelInfo struct {
|
|||||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||||
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
|
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
|
||||||
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
|
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
|
||||||
|
// SupportedInputModalities lists supported input modalities (e.g., TEXT, IMAGE, VIDEO, AUDIO)
|
||||||
|
SupportedInputModalities []string `json:"supportedInputModalities,omitempty"`
|
||||||
|
// SupportedOutputModalities lists supported output modalities (e.g., TEXT, IMAGE)
|
||||||
|
SupportedOutputModalities []string `json:"supportedOutputModalities,omitempty"`
|
||||||
|
|
||||||
// Thinking holds provider-specific reasoning/thinking budget capabilities.
|
// Thinking holds provider-specific reasoning/thinking budget capabilities.
|
||||||
// This is optional and currently used for Gemini thinking budget normalization.
|
// This is optional and currently used for Gemini thinking budget normalization.
|
||||||
@@ -501,8 +505,11 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
|||||||
if len(model.SupportedParameters) > 0 {
|
if len(model.SupportedParameters) > 0 {
|
||||||
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedEndpoints) > 0 {
|
if len(model.SupportedInputModalities) > 0 {
|
||||||
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
|
copyModel.SupportedInputModalities = append([]string(nil), model.SupportedInputModalities...)
|
||||||
|
}
|
||||||
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
|
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
return ©Model
|
return ©Model
|
||||||
}
|
}
|
||||||
@@ -1089,6 +1096,12 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
if len(model.SupportedGenerationMethods) > 0 {
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
||||||
}
|
}
|
||||||
|
if len(model.SupportedInputModalities) > 0 {
|
||||||
|
result["supportedInputModalities"] = model.SupportedInputModalities
|
||||||
|
}
|
||||||
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
|
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -45,17 +46,87 @@ const (
|
|||||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
antigravityClientSecret = "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"
|
antigravityClientSecret = "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"
|
||||||
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||||
antigravityAuthType = "antigravity"
|
antigravityAuthType = "antigravity"
|
||||||
refreshSkew = 3000 * time.Second
|
refreshSkew = 3000 * time.Second
|
||||||
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
randSourceMutex sync.Mutex
|
randSourceMutex sync.Mutex
|
||||||
|
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||||
|
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||||
|
antigravityPrimaryModelsCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
models []*registry.ModelInfo
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*registry.ModelInfo, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, cloneAntigravityModelInfo(model))
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||||
|
if model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *model
|
||||||
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
|
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
|
}
|
||||||
|
if len(model.SupportedParameters) > 0 {
|
||||||
|
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||||
|
}
|
||||||
|
if model.Thinking != nil {
|
||||||
|
thinkingClone := *model.Thinking
|
||||||
|
if len(model.Thinking.Levels) > 0 {
|
||||||
|
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
clone.Thinking = &thinkingClone
|
||||||
|
}
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||||
|
cloned := cloneAntigravityModels(models)
|
||||||
|
if len(cloned) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
antigravityPrimaryModelsCache.mu.Lock()
|
||||||
|
antigravityPrimaryModelsCache.models = cloned
|
||||||
|
antigravityPrimaryModelsCache.mu.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||||
|
antigravityPrimaryModelsCache.mu.RLock()
|
||||||
|
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||||
|
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||||
|
models := loadAntigravityPrimaryModels()
|
||||||
|
if len(models) > 0 {
|
||||||
|
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
type AntigravityExecutor struct {
|
type AntigravityExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -72,6 +143,62 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
|
|||||||
return &AntigravityExecutor{cfg: cfg}
|
return &AntigravityExecutor{cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// antigravityTransport is a singleton HTTP/1.1 transport shared by all Antigravity requests.
|
||||||
|
// It is initialized once via antigravityTransportOnce to avoid leaking a new connection pool
|
||||||
|
// (and the goroutines managing it) on every request.
|
||||||
|
var (
|
||||||
|
antigravityTransport *http.Transport
|
||||||
|
antigravityTransportOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func cloneTransportWithHTTP11(base *http.Transport) *http.Transport {
|
||||||
|
if base == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
clone := base.Clone()
|
||||||
|
clone.ForceAttemptHTTP2 = false
|
||||||
|
// Wipe TLSNextProto to prevent implicit HTTP/2 upgrade.
|
||||||
|
clone.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
|
||||||
|
if clone.TLSClientConfig == nil {
|
||||||
|
clone.TLSClientConfig = &tls.Config{}
|
||||||
|
} else {
|
||||||
|
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
|
||||||
|
}
|
||||||
|
// Actively advertise only HTTP/1.1 in the ALPN handshake.
|
||||||
|
clone.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// initAntigravityTransport creates the shared HTTP/1.1 transport exactly once.
|
||||||
|
func initAntigravityTransport() {
|
||||||
|
base, ok := http.DefaultTransport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
base = &http.Transport{}
|
||||||
|
}
|
||||||
|
antigravityTransport = cloneTransportWithHTTP11(base)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAntigravityHTTPClient creates an HTTP client specifically for Antigravity,
|
||||||
|
// enforcing HTTP/1.1 by disabling HTTP/2 to perfectly mimic Node.js https defaults.
|
||||||
|
// The underlying Transport is a singleton to avoid leaking connection pools.
|
||||||
|
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
|
antigravityTransportOnce.Do(initAntigravityTransport)
|
||||||
|
|
||||||
|
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||||
|
// If no transport is set, use the shared HTTP/1.1 transport.
|
||||||
|
if client.Transport == nil {
|
||||||
|
client.Transport = antigravityTransport
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve proxy settings from proxy-aware transports while forcing HTTP/1.1.
|
||||||
|
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||||
|
client.Transport = cloneTransportWithHTTP11(transport)
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
// Identifier returns the executor identifier.
|
// Identifier returns the executor identifier.
|
||||||
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
||||||
|
|
||||||
@@ -92,6 +219,8 @@ func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyau
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HttpRequest injects Antigravity credentials into the request and executes it.
|
// HttpRequest injects Antigravity credentials into the request and executes it.
|
||||||
|
// It uses a whitelist approach: all incoming headers are stripped and only
|
||||||
|
// the minimum set required by the Antigravity protocol is explicitly set.
|
||||||
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
if req == nil {
|
if req == nil {
|
||||||
return nil, fmt.Errorf("antigravity executor: request is nil")
|
return nil, fmt.Errorf("antigravity executor: request is nil")
|
||||||
@@ -100,10 +229,29 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
|
|||||||
ctx = req.Context()
|
ctx = req.Context()
|
||||||
}
|
}
|
||||||
httpReq := req.WithContext(ctx)
|
httpReq := req.WithContext(ctx)
|
||||||
|
|
||||||
|
// --- Whitelist: save only the headers we need from the original request ---
|
||||||
|
contentType := httpReq.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
// Wipe ALL incoming headers
|
||||||
|
for k := range httpReq.Header {
|
||||||
|
delete(httpReq.Header, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Set only the headers Antigravity actually sends ---
|
||||||
|
if contentType != "" {
|
||||||
|
httpReq.Header.Set("Content-Type", contentType)
|
||||||
|
}
|
||||||
|
// Content-Length is managed automatically by Go's http.Client from the Body
|
||||||
|
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||||
|
httpReq.Close = true // sends Connection: close
|
||||||
|
|
||||||
|
// Inject Authorization: Bearer <token>
|
||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
||||||
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +263,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
|
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
|
||||||
|
|
||||||
if isClaude || strings.Contains(baseModel, "gemini-3-pro") {
|
if isClaude || strings.Contains(baseModel, "gemini-3-pro") || strings.Contains(baseModel, "gemini-3.1-flash-image") {
|
||||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,7 +298,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
|
|
||||||
@@ -292,7 +440,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
|
|
||||||
@@ -684,7 +832,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
|
|
||||||
@@ -886,7 +1034,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
payload = deleteJSONField(payload, "request.safetySettings")
|
payload = deleteJSONField(payload, "request.safetySettings")
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -917,10 +1065,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
return cliproxyexecutor.Response{}, errReq
|
return cliproxyexecutor.Response{}, errReq
|
||||||
}
|
}
|
||||||
|
httpReq.Close = true
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||||
httpReq.Header.Set("Accept", "application/json")
|
|
||||||
if host := resolveHost(base); host != "" {
|
if host := resolveHost(base); host != "" {
|
||||||
httpReq.Host = host
|
httpReq.Host = host
|
||||||
}
|
}
|
||||||
@@ -1006,28 +1154,34 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||||
exec := &AntigravityExecutor{cfg: cfg}
|
exec := &AntigravityExecutor{cfg: cfg}
|
||||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil || token == "" {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
}
|
||||||
}
|
|
||||||
if token == "" {
|
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if updatedAuth != nil {
|
if updatedAuth != nil {
|
||||||
auth = updatedAuth
|
auth = updatedAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
modelsURL := baseURL + antigravityModelsPath
|
modelsURL := baseURL + antigravityModelsPath
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
|
||||||
if errReq != nil {
|
var payload []byte
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
|
if auth != nil && auth.Metadata != nil {
|
||||||
return nil
|
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||||
|
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
payload = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||||
|
if errReq != nil {
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
|
}
|
||||||
|
httpReq.Close = true
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||||
@@ -1038,15 +1192,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -1058,22 +1210,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
|
return fallbackAntigravityPrimaryModels()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
|
if idx+1 < len(baseURLs) {
|
||||||
return nil
|
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
result := gjson.GetBytes(bodyBytes, "models")
|
result := gjson.GetBytes(bodyBytes, "models")
|
||||||
if !result.Exists() {
|
if !result.Exists() {
|
||||||
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
|
if idx+1 < len(baseURLs) {
|
||||||
return nil
|
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
@@ -1085,7 +1242,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch modelID {
|
switch modelID {
|
||||||
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
|
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
modelCfg := modelConfig[modelID]
|
modelCfg := modelConfig[modelID]
|
||||||
@@ -1107,6 +1264,29 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
OwnedBy: antigravityAuthType,
|
OwnedBy: antigravityAuthType,
|
||||||
Type: antigravityAuthType,
|
Type: antigravityAuthType,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build input modalities from upstream capability flags.
|
||||||
|
inputModalities := []string{"TEXT"}
|
||||||
|
if modelData.Get("supportsImages").Bool() {
|
||||||
|
inputModalities = append(inputModalities, "IMAGE")
|
||||||
|
}
|
||||||
|
if modelData.Get("supportsVideo").Bool() {
|
||||||
|
inputModalities = append(inputModalities, "VIDEO")
|
||||||
|
}
|
||||||
|
modelInfo.SupportedInputModalities = inputModalities
|
||||||
|
modelInfo.SupportedOutputModalities = []string{"TEXT"}
|
||||||
|
|
||||||
|
// Token limits from upstream.
|
||||||
|
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||||
|
modelInfo.InputTokenLimit = int(maxTok)
|
||||||
|
}
|
||||||
|
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||||
|
modelInfo.OutputTokenLimit = int(maxOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Supported generation methods (Gemini v1beta convention).
|
||||||
|
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
|
||||||
|
|
||||||
// Look up Thinking support from static config using upstream model name.
|
// Look up Thinking support from static config using upstream model name.
|
||||||
if modelCfg != nil {
|
if modelCfg != nil {
|
||||||
if modelCfg.Thinking != nil {
|
if modelCfg.Thinking != nil {
|
||||||
@@ -1118,9 +1298,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
}
|
}
|
||||||
models = append(models, modelInfo)
|
models = append(models, modelInfo)
|
||||||
}
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||||
|
return fallbackAntigravityPrimaryModels()
|
||||||
|
}
|
||||||
|
storeAntigravityPrimaryModels(models)
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
return nil
|
return fallbackAntigravityPrimaryModels()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||||
@@ -1165,10 +1354,11 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
|||||||
return auth, errReq
|
return auth, errReq
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Host", "oauth2.googleapis.com")
|
httpReq.Header.Set("Host", "oauth2.googleapis.com")
|
||||||
httpReq.Header.Set("User-Agent", defaultAntigravityAgent)
|
|
||||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
// Real Antigravity uses Go's default User-Agent for OAuth token refresh
|
||||||
|
httpReq.Header.Set("User-Agent", "Go-http-client/2.0")
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return auth, errDo
|
return auth, errDo
|
||||||
@@ -1239,7 +1429,7 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
|
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
|
||||||
if errFetch != nil {
|
if errFetch != nil {
|
||||||
return errFetch
|
return errFetch
|
||||||
@@ -1293,7 +1483,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||||
|
|
||||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
|
||||||
payloadStr := string(payload)
|
payloadStr := string(payload)
|
||||||
paths := make([]string, 0)
|
paths := make([]string, 0)
|
||||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||||
@@ -1307,18 +1497,18 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if useAntigravitySchema {
|
// if useAntigravitySchema {
|
||||||
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||||
|
|
||||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
// for _, partResult := range systemInstructionPartsResult.Array() {
|
||||||
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") {
|
if strings.Contains(modelName, "claude") {
|
||||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||||
@@ -1330,14 +1520,10 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
return nil, errReq
|
return nil, errReq
|
||||||
}
|
}
|
||||||
|
httpReq.Close = true
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||||
if stream {
|
|
||||||
httpReq.Header.Set("Accept", "text/event-stream")
|
|
||||||
} else {
|
|
||||||
httpReq.Header.Set("Accept", "application/json")
|
|
||||||
}
|
|
||||||
if host := resolveHost(base); host != "" {
|
if host := resolveHost(base); host != "" {
|
||||||
httpReq.Host = host
|
httpReq.Host = host
|
||||||
}
|
}
|
||||||
@@ -1549,7 +1735,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
|||||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||||
template, _ = sjson.Set(template, "requestType", "agent")
|
|
||||||
|
isImageModel := strings.Contains(modelName, "image")
|
||||||
|
|
||||||
|
var reqType string
|
||||||
|
if isImageModel {
|
||||||
|
reqType = "image_gen"
|
||||||
|
} else {
|
||||||
|
reqType = "agent"
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "requestType", reqType)
|
||||||
|
|
||||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
@@ -1557,8 +1752,13 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
|||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
|
||||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
if isImageModel {
|
||||||
|
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||||
|
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||||
|
}
|
||||||
|
|
||||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||||
@@ -1572,6 +1772,10 @@ func generateRequestID() string {
|
|||||||
return "agent-" + uuid.NewString()
|
return "agent-" + uuid.NewString()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateImageGenRequestID() string {
|
||||||
|
return fmt.Sprintf("image_gen/%d/%s/12", time.Now().UnixMilli(), uuid.NewString())
|
||||||
|
}
|
||||||
|
|
||||||
func generateSessionID() string {
|
func generateSessionID() string {
|
||||||
randSourceMutex.Lock()
|
randSourceMutex.Lock()
|
||||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any
|
|||||||
"properties": {
|
"properties": {
|
||||||
"mode": {
|
"mode": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
"deprecated": true,
|
||||||
"enum": ["a", "b"],
|
"enum": ["a", "b"],
|
||||||
"enumTitles": ["A", "B"]
|
"enumTitles": ["A", "B"]
|
||||||
}
|
}
|
||||||
@@ -156,4 +157,7 @@ func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]a
|
|||||||
if _, ok := mode["enumTitles"]; ok {
|
if _, ok := mode["enumTitles"]; ok {
|
||||||
t.Fatalf("enumTitles should be removed from nested schema")
|
t.Fatalf("enumTitles should be removed from nested schema")
|
||||||
}
|
}
|
||||||
|
if _, ok := mode["deprecated"]; ok {
|
||||||
|
t.Fatalf("deprecated should be removed from nested schema")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||||
|
antigravityPrimaryModelsCache.mu.Lock()
|
||||||
|
antigravityPrimaryModelsCache.models = nil
|
||||||
|
antigravityPrimaryModelsCache.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||||
|
resetAntigravityPrimaryModelsCacheForTest()
|
||||||
|
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||||
|
|
||||||
|
seed := []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4-5"},
|
||||||
|
{ID: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||||
|
t.Fatal("expected non-empty model list to update primary cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||||
|
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||||
|
}
|
||||||
|
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||||
|
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := loadAntigravityPrimaryModels()
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||||
|
}
|
||||||
|
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||||
|
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||||
|
resetAntigravityPrimaryModelsCacheForTest()
|
||||||
|
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||||
|
|
||||||
|
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||||
|
ID: "gpt-5",
|
||||||
|
DisplayName: "GPT-5",
|
||||||
|
SupportedGenerationMethods: []string{"generateContent"},
|
||||||
|
SupportedParameters: []string{"temperature"},
|
||||||
|
Thinking: ®istry.ThinkingSupport{
|
||||||
|
Levels: []string{"high"},
|
||||||
|
},
|
||||||
|
}}); !updated {
|
||||||
|
t.Fatal("expected model cache update")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := loadAntigravityPrimaryModels()
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatalf("expected one cached model, got %d", len(got))
|
||||||
|
}
|
||||||
|
got[0].ID = "mutated-id"
|
||||||
|
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||||
|
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||||
|
}
|
||||||
|
if len(got[0].SupportedParameters) > 0 {
|
||||||
|
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||||
|
}
|
||||||
|
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||||
|
got[0].Thinking.Levels[0] = "mutated-level"
|
||||||
|
}
|
||||||
|
|
||||||
|
again := loadAntigravityPrimaryModels()
|
||||||
|
if len(again) != 1 {
|
||||||
|
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||||
|
}
|
||||||
|
if again[0].ID != "gpt-5" {
|
||||||
|
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||||
|
}
|
||||||
|
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||||
|
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||||
|
}
|
||||||
|
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||||
|
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||||
|
}
|
||||||
|
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||||
|
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,9 +6,14 @@ import (
|
|||||||
"compress/flate"
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -36,7 +41,9 @@ type ClaudeExecutor struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
const claudeToolPrefix = "proxy_"
|
// claudeToolPrefix is empty to match real Claude Code behavior (no tool name prefix).
|
||||||
|
// Previously "proxy_" was used but this is a detectable fingerprint difference.
|
||||||
|
const claudeToolPrefix = ""
|
||||||
|
|
||||||
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||||
|
|
||||||
@@ -130,6 +137,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
body = ensureCacheControl(body)
|
body = ensureCacheControl(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
|
||||||
|
// Cloaking and ensureCacheControl may push the total over 4 when the client
|
||||||
|
// (e.g. Amp CLI) already sends multiple cache_control blocks.
|
||||||
|
body = enforceCacheControlLimit(body, 4)
|
||||||
|
|
||||||
|
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
|
||||||
|
// A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages).
|
||||||
|
body = normalizeCacheControlTTL(body)
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -171,11 +187,27 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||||
|
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||||
|
// compression. This keeps error-path behaviour consistent with the success path.
|
||||||
|
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||||
|
if decErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||||
|
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||||
|
logWithRequestID(ctx).Warn(msg)
|
||||||
|
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||||
|
}
|
||||||
|
b, readErr := io.ReadAll(errBody)
|
||||||
|
if readErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||||
|
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||||
|
logWithRequestID(ctx).Warn(msg)
|
||||||
|
b = []byte(msg)
|
||||||
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := errBody.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
}
|
}
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -271,6 +303,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
body = ensureCacheControl(body)
|
body = ensureCacheControl(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
|
||||||
|
body = enforceCacheControlLimit(body, 4)
|
||||||
|
|
||||||
|
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
|
||||||
|
body = normalizeCacheControlTTL(body)
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -312,10 +350,26 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||||
|
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||||
|
// compression. This keeps error-path behaviour consistent with the success path.
|
||||||
|
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||||
|
if decErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||||
|
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||||
|
logWithRequestID(ctx).Warn(msg)
|
||||||
|
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||||
|
}
|
||||||
|
b, readErr := io.ReadAll(errBody)
|
||||||
|
if readErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||||
|
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||||
|
logWithRequestID(ctx).Warn(msg)
|
||||||
|
b = []byte(msg)
|
||||||
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := errBody.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
}
|
}
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
@@ -420,6 +474,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
body = checkSystemInstructions(body)
|
body = checkSystemInstructions(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keep count_tokens requests compatible with Anthropic cache-control constraints too.
|
||||||
|
body = enforceCacheControlLimit(body, 4)
|
||||||
|
body = normalizeCacheControlTTL(body)
|
||||||
|
|
||||||
// Extract betas from body and convert to header (for count_tokens too)
|
// Extract betas from body and convert to header (for count_tokens too)
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -459,9 +517,25 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(resp.Body)
|
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||||
|
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||||
|
// compression. This keeps error-path behaviour consistent with the success path.
|
||||||
|
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
|
||||||
|
if decErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||||
|
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||||
|
logWithRequestID(ctx).Warn(msg)
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
|
||||||
|
}
|
||||||
|
b, readErr := io.ReadAll(errBody)
|
||||||
|
if readErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||||
|
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||||
|
logWithRequestID(ctx).Warn(msg)
|
||||||
|
b = []byte(msg)
|
||||||
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
if errClose := errBody.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
}
|
}
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
|
||||||
@@ -554,6 +628,12 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
|||||||
if toolChoiceType == "any" || toolChoiceType == "tool" {
|
if toolChoiceType == "any" || toolChoiceType == "tool" {
|
||||||
// Remove thinking configuration entirely to avoid API error
|
// Remove thinking configuration entirely to avoid API error
|
||||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||||
|
// Adaptive thinking may also set output_config.effort; remove it to avoid
|
||||||
|
// leaking thinking controls when tool_choice forces tool use.
|
||||||
|
body, _ = sjson.DeleteBytes(body, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(body, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
body, _ = sjson.DeleteBytes(body, "output_config")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
@@ -576,12 +656,61 @@ func (c *compositeReadCloser) Close() error {
|
|||||||
return firstErr
|
return firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peekableBody wraps a bufio.Reader around the original ReadCloser so that
|
||||||
|
// magic bytes can be inspected without consuming them from the stream.
|
||||||
|
type peekableBody struct {
|
||||||
|
*bufio.Reader
|
||||||
|
closer io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *peekableBody) Close() error {
|
||||||
|
return p.closer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
|
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
|
||||||
if body == nil {
|
if body == nil {
|
||||||
return nil, fmt.Errorf("response body is nil")
|
return nil, fmt.Errorf("response body is nil")
|
||||||
}
|
}
|
||||||
if contentEncoding == "" {
|
if contentEncoding == "" {
|
||||||
return body, nil
|
// No Content-Encoding header. Attempt best-effort magic-byte detection to
|
||||||
|
// handle misbehaving upstreams that compress without setting the header.
|
||||||
|
// Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences;
|
||||||
|
// br and deflate have none and are left as-is.
|
||||||
|
// The bufio wrapper preserves unread bytes so callers always see the full
|
||||||
|
// stream regardless of whether decompression was applied.
|
||||||
|
pb := &peekableBody{Reader: bufio.NewReader(body), closer: body}
|
||||||
|
magic, peekErr := pb.Peek(4)
|
||||||
|
if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) {
|
||||||
|
switch {
|
||||||
|
case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b:
|
||||||
|
gzipReader, gzErr := gzip.NewReader(pb)
|
||||||
|
if gzErr != nil {
|
||||||
|
_ = pb.Close()
|
||||||
|
return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr)
|
||||||
|
}
|
||||||
|
return &compositeReadCloser{
|
||||||
|
Reader: gzipReader,
|
||||||
|
closers: []func() error{
|
||||||
|
gzipReader.Close,
|
||||||
|
pb.Close,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd:
|
||||||
|
decoder, zdErr := zstd.NewReader(pb)
|
||||||
|
if zdErr != nil {
|
||||||
|
_ = pb.Close()
|
||||||
|
return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr)
|
||||||
|
}
|
||||||
|
return &compositeReadCloser{
|
||||||
|
Reader: decoder,
|
||||||
|
closers: []func() error{
|
||||||
|
func() error { decoder.Close(); return nil },
|
||||||
|
pb.Close,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pb, nil
|
||||||
}
|
}
|
||||||
encodings := strings.Split(contentEncoding, ",")
|
encodings := strings.Split(contentEncoding, ",")
|
||||||
for _, raw := range encodings {
|
for _, raw := range encodings {
|
||||||
@@ -696,23 +825,29 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
ginHeaders = ginCtx.Request.Header
|
ginHeaders = ginCtx.Request.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
promptCachingBeta := "prompt-caching-2024-07-31"
|
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14," + promptCachingBeta
|
|
||||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||||
baseBetas = val
|
baseBetas = val
|
||||||
if !strings.Contains(val, "oauth") {
|
if !strings.Contains(val, "oauth") {
|
||||||
baseBetas += ",oauth-2025-04-20"
|
baseBetas += ",oauth-2025-04-20"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !strings.Contains(baseBetas, promptCachingBeta) {
|
|
||||||
baseBetas += "," + promptCachingBeta
|
hasClaude1MHeader := false
|
||||||
|
if ginHeaders != nil {
|
||||||
|
if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok {
|
||||||
|
hasClaude1MHeader = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge extra betas from request body
|
// Merge extra betas from request body and request flags.
|
||||||
if len(extraBetas) > 0 {
|
if len(extraBetas) > 0 || hasClaude1MHeader {
|
||||||
existingSet := make(map[string]bool)
|
existingSet := make(map[string]bool)
|
||||||
for _, b := range strings.Split(baseBetas, ",") {
|
for _, b := range strings.Split(baseBetas, ",") {
|
||||||
existingSet[strings.TrimSpace(b)] = true
|
betaName := strings.TrimSpace(b)
|
||||||
|
if betaName != "" {
|
||||||
|
existingSet[betaName] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, beta := range extraBetas {
|
for _, beta := range extraBetas {
|
||||||
beta = strings.TrimSpace(beta)
|
beta = strings.TrimSpace(beta)
|
||||||
@@ -721,14 +856,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
existingSet[beta] = true
|
existingSet[beta] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] {
|
||||||
|
baseBetas += ",context-1m-2025-08-07"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||||
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||||
@@ -737,13 +874,28 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
// For User-Agent, only forward the client's header if it's already a Claude Code client.
|
||||||
|
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
|
||||||
|
// to avoid leaking the real client identity during cloaking.
|
||||||
|
clientUA := ""
|
||||||
|
if ginHeaders != nil {
|
||||||
|
clientUA = ginHeaders.Get("User-Agent")
|
||||||
|
}
|
||||||
|
if isClaudeCodeClient(clientUA) {
|
||||||
|
r.Header.Set("User-Agent", clientUA)
|
||||||
|
} else {
|
||||||
|
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
||||||
|
}
|
||||||
r.Header.Set("Connection", "keep-alive")
|
r.Header.Set("Connection", "keep-alive")
|
||||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
|
||||||
if stream {
|
if stream {
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
|
// SSE streams must not be compressed: the downstream scanner reads
|
||||||
|
// line-delimited text and cannot parse compressed bytes. Using
|
||||||
|
// "identity" tells the upstream to send an uncompressed stream.
|
||||||
|
r.Header.Set("Accept-Encoding", "identity")
|
||||||
} else {
|
} else {
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||||
}
|
}
|
||||||
// Keep OS/Arch mapping dynamic (not configurable).
|
// Keep OS/Arch mapping dynamic (not configurable).
|
||||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||||
@@ -752,6 +904,12 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
}
|
}
|
||||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||||
|
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||||
|
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||||
|
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||||
|
if stream {
|
||||||
|
r.Header.Set("Accept-Encoding", "identity")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||||
@@ -771,22 +929,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func checkSystemInstructions(payload []byte) []byte {
|
func checkSystemInstructions(payload []byte) []byte {
|
||||||
system := gjson.GetBytes(payload, "system")
|
return checkSystemInstructionsWithMode(payload, false)
|
||||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
|
||||||
if system.IsArray() {
|
|
||||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
|
||||||
system.ForEach(func(_, part gjson.Result) bool {
|
|
||||||
if part.Get("type").String() == "text" {
|
|
||||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
|
||||||
}
|
|
||||||
return payload
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isClaudeOAuthToken(apiKey string) bool {
|
func isClaudeOAuthToken(apiKey string) bool {
|
||||||
@@ -1060,33 +1203,73 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
|||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkSystemInstructionsWithMode injects Claude Code system prompt.
|
// generateBillingHeader creates the x-anthropic-billing-header text block that
|
||||||
// In strict mode, it replaces all user system messages.
|
// real Claude Code prepends to every system prompt array.
|
||||||
// In non-strict mode (default), it prepends to existing system messages.
|
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
|
||||||
|
func generateBillingHeader(payload []byte) string {
|
||||||
|
// Generate a deterministic cch hash from the payload content (system + messages + tools).
|
||||||
|
// Real Claude Code uses a 5-char hex hash that varies per request.
|
||||||
|
h := sha256.Sum256(payload)
|
||||||
|
cch := hex.EncodeToString(h[:])[:5]
|
||||||
|
|
||||||
|
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
|
||||||
|
buildBytes := make([]byte, 2)
|
||||||
|
_, _ = rand.Read(buildBytes)
|
||||||
|
buildHash := hex.EncodeToString(buildBytes)[:3]
|
||||||
|
|
||||||
|
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
|
||||||
|
//
|
||||||
|
// system[0]: billing header (no cache_control)
|
||||||
|
// system[1]: agent identifier (no cache_control)
|
||||||
|
// system[2..]: user system messages (cache_control added when missing)
|
||||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||||
system := gjson.GetBytes(payload, "system")
|
system := gjson.GetBytes(payload, "system")
|
||||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
|
||||||
|
billingText := generateBillingHeader(payload)
|
||||||
|
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
|
||||||
|
// No cache_control on the agent block. It is a cloaking artifact with zero cache
|
||||||
|
// value (the last system block is what actually triggers caching of all system content).
|
||||||
|
// Including any cache_control here creates an intra-system TTL ordering violation
|
||||||
|
// when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta
|
||||||
|
// forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m).
|
||||||
|
agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}`
|
||||||
|
|
||||||
if strictMode {
|
if strictMode {
|
||||||
// Strict mode: replace all system messages with Claude Code prompt only
|
// Strict mode: billing header + agent identifier only
|
||||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
result := "[" + billingBlock + "," + agentBlock + "]"
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
|
// Non-strict mode: billing header + agent identifier + user system messages
|
||||||
if system.IsArray() {
|
// Skip if already injected
|
||||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
firstText := gjson.GetBytes(payload, "system.0.text").String()
|
||||||
system.ForEach(func(_, part gjson.Result) bool {
|
if strings.HasPrefix(firstText, "x-anthropic-billing-header:") {
|
||||||
if part.Get("type").String() == "text" {
|
return payload
|
||||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result := "[" + billingBlock + "," + agentBlock
|
||||||
|
if system.IsArray() {
|
||||||
|
system.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
// Add cache_control to user system messages if not present.
|
||||||
|
// Do NOT add ttl — let it inherit the default (5m) to avoid
|
||||||
|
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
||||||
|
partJSON := part.Raw
|
||||||
|
if !part.Get("cache_control").Exists() {
|
||||||
|
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
|
||||||
|
}
|
||||||
|
result += "," + partJSON
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
result += "]"
|
||||||
|
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1224,6 +1407,313 @@ func countCacheControls(payload []byte) int {
|
|||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parsePayloadObject(payload []byte) (map[string]any, bool) {
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
var root map[string]any
|
||||||
|
if err := json.Unmarshal(payload, &root); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return root, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalPayloadObject(original []byte, root map[string]any) []byte {
|
||||||
|
if root == nil {
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
out, err := json.Marshal(root)
|
||||||
|
if err != nil {
|
||||||
|
return original
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func asObject(v any) (map[string]any, bool) {
|
||||||
|
obj, ok := v.(map[string]any)
|
||||||
|
return obj, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func asArray(v any) ([]any, bool) {
|
||||||
|
arr, ok := v.([]any)
|
||||||
|
return arr, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func countCacheControlsMap(root map[string]any) int {
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
if system, ok := asArray(root["system"]); ok {
|
||||||
|
for _, item := range system {
|
||||||
|
if obj, ok := asObject(item); ok {
|
||||||
|
if _, exists := obj["cache_control"]; exists {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools, ok := asArray(root["tools"]); ok {
|
||||||
|
for _, item := range tools {
|
||||||
|
if obj, ok := asObject(item); ok {
|
||||||
|
if _, exists := obj["cache_control"]; exists {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages, ok := asArray(root["messages"]); ok {
|
||||||
|
for _, msg := range messages {
|
||||||
|
msgObj, ok := asObject(msg)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, ok := asArray(msgObj["content"])
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, item := range content {
|
||||||
|
if obj, ok := asObject(item); ok {
|
||||||
|
if _, exists := obj["cache_control"]; exists {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
|
||||||
|
ccRaw, exists := obj["cache_control"]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cc, ok := asObject(ccRaw)
|
||||||
|
if !ok {
|
||||||
|
*seen5m = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ttlRaw, ttlExists := cc["ttl"]
|
||||||
|
ttl, ttlIsString := ttlRaw.(string)
|
||||||
|
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||||
|
*seen5m = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if *seen5m {
|
||||||
|
delete(cc, "ttl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findLastCacheControlIndex(arr []any) int {
|
||||||
|
last := -1
|
||||||
|
for idx, item := range arr {
|
||||||
|
obj, ok := asObject(item)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := obj["cache_control"]; exists {
|
||||||
|
last = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return last
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
|
||||||
|
for idx, item := range arr {
|
||||||
|
if *excess <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
obj, ok := asObject(item)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
|
||||||
|
delete(obj, "cache_control")
|
||||||
|
*excess--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripAllCacheControl(arr []any, excess *int) {
|
||||||
|
for _, item := range arr {
|
||||||
|
if *excess <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
obj, ok := asObject(item)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := obj["cache_control"]; exists {
|
||||||
|
delete(obj, "cache_control")
|
||||||
|
*excess--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripMessageCacheControl(messages []any, excess *int) {
|
||||||
|
for _, msg := range messages {
|
||||||
|
if *excess <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msgObj, ok := asObject(msg)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, ok := asArray(msgObj["content"])
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, item := range content {
|
||||||
|
if *excess <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
obj, ok := asObject(item)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := obj["cache_control"]; exists {
|
||||||
|
delete(obj, "cache_control")
|
||||||
|
*excess--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
|
||||||
|
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
|
||||||
|
// appear after a 5m-TTL block anywhere in the evaluation order.
|
||||||
|
//
|
||||||
|
// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages.
|
||||||
|
// Within each section, blocks are evaluated in array order. A 5m (default) block
|
||||||
|
// followed by a 1h block at ANY later position is an error — including within
|
||||||
|
// the same section (e.g. system[1]=5m then system[3]=1h).
|
||||||
|
//
|
||||||
|
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
|
||||||
|
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
|
||||||
|
func normalizeCacheControlTTL(payload []byte) []byte {
|
||||||
|
root, ok := parsePayloadObject(payload)
|
||||||
|
if !ok {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
seen5m := false
|
||||||
|
|
||||||
|
if tools, ok := asArray(root["tools"]); ok {
|
||||||
|
for _, tool := range tools {
|
||||||
|
if obj, ok := asObject(tool); ok {
|
||||||
|
normalizeTTLForBlock(obj, &seen5m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if system, ok := asArray(root["system"]); ok {
|
||||||
|
for _, item := range system {
|
||||||
|
if obj, ok := asObject(item); ok {
|
||||||
|
normalizeTTLForBlock(obj, &seen5m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages, ok := asArray(root["messages"]); ok {
|
||||||
|
for _, msg := range messages {
|
||||||
|
msgObj, ok := asObject(msg)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, ok := asArray(msgObj["content"])
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, item := range content {
|
||||||
|
if obj, ok := asObject(item); ok {
|
||||||
|
normalizeTTLForBlock(obj, &seen5m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalPayloadObject(payload, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enforceCacheControlLimit removes excess cache_control blocks from a payload
|
||||||
|
// so the total does not exceed the Anthropic API limit (currently 4).
|
||||||
|
//
|
||||||
|
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
|
||||||
|
// The most valuable breakpoints are:
|
||||||
|
// 1. Last tool — caches ALL tool definitions
|
||||||
|
// 2. Last system block — caches ALL system content
|
||||||
|
// 3. Recent messages — cache conversation context
|
||||||
|
//
|
||||||
|
// Removal priority (strip lowest-value first):
|
||||||
|
//
|
||||||
|
// Phase 1: system blocks earliest-first, preserving the last one.
|
||||||
|
// Phase 2: tool blocks earliest-first, preserving the last one.
|
||||||
|
// Phase 3: message content blocks earliest-first.
|
||||||
|
// Phase 4: remaining system blocks (last system).
|
||||||
|
// Phase 5: remaining tool blocks (last tool).
|
||||||
|
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
|
||||||
|
root, ok := parsePayloadObject(payload)
|
||||||
|
if !ok {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
total := countCacheControlsMap(root)
|
||||||
|
if total <= maxBlocks {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
excess := total - maxBlocks
|
||||||
|
|
||||||
|
var system []any
|
||||||
|
if arr, ok := asArray(root["system"]); ok {
|
||||||
|
system = arr
|
||||||
|
}
|
||||||
|
var tools []any
|
||||||
|
if arr, ok := asArray(root["tools"]); ok {
|
||||||
|
tools = arr
|
||||||
|
}
|
||||||
|
var messages []any
|
||||||
|
if arr, ok := asArray(root["messages"]); ok {
|
||||||
|
messages = arr
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(system) > 0 {
|
||||||
|
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
|
||||||
|
}
|
||||||
|
if excess <= 0 {
|
||||||
|
return marshalPayloadObject(payload, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
|
||||||
|
}
|
||||||
|
if excess <= 0 {
|
||||||
|
return marshalPayloadObject(payload, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(messages) > 0 {
|
||||||
|
stripMessageCacheControl(messages, &excess)
|
||||||
|
}
|
||||||
|
if excess <= 0 {
|
||||||
|
return marshalPayloadObject(payload, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(system) > 0 {
|
||||||
|
stripAllCacheControl(system, &excess)
|
||||||
|
}
|
||||||
|
if excess <= 0 {
|
||||||
|
return marshalPayloadObject(payload, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
stripAllCacheControl(tools, &excess)
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalPayloadObject(payload, root)
|
||||||
|
}
|
||||||
|
|
||||||
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
||||||
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
|
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
|
||||||
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
|
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
|
||||||
|
|||||||
@@ -2,12 +2,15 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -348,3 +351,619 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
|||||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
|
||||||
|
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||||
|
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out := normalizeCacheControlTTL(payload)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" {
|
||||||
|
t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
|
||||||
|
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||||
|
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||||
|
],
|
||||||
|
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||||
|
"messages": [
|
||||||
|
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
|
||||||
|
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out := enforceCacheControlLimit(payload, 4)
|
||||||
|
|
||||||
|
if got := countCacheControls(out); got != 4 {
|
||||||
|
t.Fatalf("cache_control count = %d, want 4", got)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||||
|
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(out, "tools.1.cache_control").Exists() {
|
||||||
|
t.Fatalf("tools.1.cache_control (last tool) should be preserved")
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() {
|
||||||
|
t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||||
|
{"name":"t2","cache_control":{"type":"ephemeral"}},
|
||||||
|
{"name":"t3","cache_control":{"type":"ephemeral"}},
|
||||||
|
{"name":"t4","cache_control":{"type":"ephemeral"}},
|
||||||
|
{"name":"t5","cache_control":{"type":"ephemeral"}}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out := enforceCacheControlLimit(payload, 4)
|
||||||
|
|
||||||
|
if got := countCacheControls(out); got != 4 {
|
||||||
|
t.Fatalf("cache_control count = %d, want 4", got)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||||
|
t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(out, "tools.4.cache_control").Exists() {
|
||||||
|
t.Fatalf("last tool cache_control should be preserved when possible")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
|
||||||
|
var seenBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
seenBody = bytes.Clone(body)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"input_tokens":42}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
|
||||||
|
payload := []byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||||
|
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||||
|
],
|
||||||
|
"system": [
|
||||||
|
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||||
|
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||||
|
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-haiku-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountTokens error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(seenBody) == 0 {
|
||||||
|
t.Fatal("expected count_tokens request body to be captured")
|
||||||
|
}
|
||||||
|
if got := countCacheControls(seenBody); got > 4 {
|
||||||
|
t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
|
||||||
|
}
|
||||||
|
if hasTTLOrderingViolation(seenBody) {
|
||||||
|
t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasTTLOrderingViolation(payload []byte) bool {
|
||||||
|
seen5m := false
|
||||||
|
violates := false
|
||||||
|
|
||||||
|
checkCC := func(cc gjson.Result) {
|
||||||
|
if !cc.Exists() || violates {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ttl := cc.Get("ttl").String()
|
||||||
|
if ttl != "1h" {
|
||||||
|
seen5m = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if seen5m {
|
||||||
|
violates = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := gjson.GetBytes(payload, "tools")
|
||||||
|
if tools.IsArray() {
|
||||||
|
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||||
|
checkCC(tool.Get("cache_control"))
|
||||||
|
return !violates
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
system := gjson.GetBytes(payload, "system")
|
||||||
|
if system.IsArray() {
|
||||||
|
system.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
checkCC(item.Get("cache_control"))
|
||||||
|
return !violates
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
if messages.IsArray() {
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if content.IsArray() {
|
||||||
|
content.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
checkCC(item.Get("cache_control"))
|
||||||
|
return !violates
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return !violates
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return violates
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||||
|
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||||
|
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||||
|
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||||
|
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||||
|
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||||
|
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClaudeExecutorInvalidCompressedErrorBody(
|
||||||
|
t *testing.T,
|
||||||
|
invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte("not-a-valid-gzip-stream"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
err := invoke(executor, auth, payload)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "failed to decode error response body") {
|
||||||
|
t.Fatalf("expected decode failure message, got: %v", err)
|
||||||
|
}
|
||||||
|
if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
||||||
|
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
||||||
|
// compressed SSE body that would silently break the line scanner.
|
||||||
|
func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) {
|
||||||
|
var gotEncoding, gotAccept string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||||
|
gotAccept = r.Header.Get("Accept")
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream error: %v", err)
|
||||||
|
}
|
||||||
|
for chunk := range result.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotEncoding != "identity" {
|
||||||
|
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "identity")
|
||||||
|
}
|
||||||
|
if gotAccept != "text/event-stream" {
|
||||||
|
t.Errorf("Accept = %q, want %q", gotAccept, "text/event-stream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming
|
||||||
|
// requests keep the full accept-encoding to allow response compression (which
|
||||||
|
// decodeResponseBody handles correctly).
|
||||||
|
func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) {
|
||||||
|
var gotEncoding, gotAccept string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||||
|
gotAccept = r.Header.Get("Accept")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotEncoding != "gzip, deflate, br, zstd" {
|
||||||
|
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd")
|
||||||
|
}
|
||||||
|
if gotAccept != "application/json" {
|
||||||
|
t.Errorf("Accept = %q, want %q", gotAccept, "application/json")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming
|
||||||
|
// HTTP 200 response with Content-Encoding: gzip is correctly decompressed before
|
||||||
|
// the line scanner runs, so SSE chunks are not silently dropped.
|
||||||
|
func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
|
||||||
|
_ = gz.Close()
|
||||||
|
compressedBody := buf.Bytes()
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
_, _ = w.Write(compressedBody)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var combined strings.Builder
|
||||||
|
for chunk := range result.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("chunk error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
combined.Write(chunk.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
if combined.Len() == 0 {
|
||||||
|
t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)")
|
||||||
|
}
|
||||||
|
if !strings.Contains(combined.String(), "message_stop") {
|
||||||
|
t.Errorf("expected SSE content in chunks, got: %q", combined.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody
|
||||||
|
// detects gzip-compressed content via magic bytes even when Content-Encoding is absent.
|
||||||
|
func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
|
||||||
|
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
_, _ = gz.Write([]byte(plaintext))
|
||||||
|
_ = gz.Close()
|
||||||
|
|
||||||
|
rc := io.NopCloser(&buf)
|
||||||
|
decoded, err := decodeResponseBody(rc, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeResponseBody error: %v", err)
|
||||||
|
}
|
||||||
|
defer decoded.Close()
|
||||||
|
|
||||||
|
got, err := io.ReadAll(decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != plaintext {
|
||||||
|
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
|
||||||
|
// plain text untouched when Content-Encoding is absent and no magic bytes match.
|
||||||
|
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
|
||||||
|
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||||
|
rc := io.NopCloser(strings.NewReader(plaintext))
|
||||||
|
decoded, err := decodeResponseBody(rc, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeResponseBody error: %v", err)
|
||||||
|
}
|
||||||
|
defer decoded.Close()
|
||||||
|
|
||||||
|
got, err := io.ReadAll(decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != plaintext {
|
||||||
|
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full
|
||||||
|
// pipeline: when the upstream returns a gzip-compressed SSE body WITHOUT setting
|
||||||
|
// Content-Encoding (a misbehaving upstream), the magic-byte sniff in
|
||||||
|
// decodeResponseBody still decompresses it, so chunks reach the caller.
|
||||||
|
func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
|
||||||
|
_ = gz.Close()
|
||||||
|
compressedBody := buf.Bytes()
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||||
|
_, _ = w.Write(compressedBody)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var combined strings.Builder
|
||||||
|
for chunk := range result.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("chunk error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
combined.Write(chunk.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
if combined.Len() == 0 {
|
||||||
|
t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)")
|
||||||
|
}
|
||||||
|
if !strings.Contains(combined.String(), "message_stop") {
|
||||||
|
t.Errorf("unexpected chunk content: %q", combined.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
|
||||||
|
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
|
||||||
|
// path's enforced identity encoding.
|
||||||
|
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
||||||
|
var gotEncoding string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream error: %v", err)
|
||||||
|
}
|
||||||
|
for chunk := range result.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotEncoding != "identity" {
|
||||||
|
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
||||||
|
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
|
||||||
|
// Content-Encoding is absent.
|
||||||
|
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
||||||
|
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc, err := zstd.NewWriter(&buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("zstd.NewWriter: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = enc.Write([]byte(plaintext))
|
||||||
|
_ = enc.Close()
|
||||||
|
|
||||||
|
rc := io.NopCloser(&buf)
|
||||||
|
decoded, err := decodeResponseBody(rc, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeResponseBody error: %v", err)
|
||||||
|
}
|
||||||
|
defer decoded.Close()
|
||||||
|
|
||||||
|
got, err := io.ReadAll(decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != plaintext {
|
||||||
|
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
|
||||||
|
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
|
||||||
|
// the Content-Encoding header. This closes the gap left by PR #1771, which only
|
||||||
|
// fixed header-declared compression on the error path.
|
||||||
|
func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
|
||||||
|
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}`
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
_, _ = gz.Write([]byte(errJSON))
|
||||||
|
_ = gz.Close()
|
||||||
|
compressedBody := buf.Bytes()
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write(compressedBody)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error for 400 response, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "test error") {
|
||||||
|
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies
|
||||||
|
// the same for the streaming executor: 4xx gzip body without Content-Encoding is
|
||||||
|
// decoded and the error message is readable.
|
||||||
|
func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
|
||||||
|
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}`
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
_, _ = gz.Write([]byte(errJSON))
|
||||||
|
_ = gz.Close()
|
||||||
|
compressedBody := buf.Bytes()
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write(compressedBody)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error for 400 response, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "stream test error") {
|
||||||
|
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,17 +9,18 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
|
// userIDPattern matches Claude Code format: user_[64-hex]_account_[uuid]_session_[uuid]
|
||||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||||
|
|
||||||
// generateFakeUserID generates a fake user ID in Claude Code format.
|
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||||
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
// Format: user_[64-hex-chars]_account_[UUID-v4]_session_[UUID-v4]
|
||||||
func generateFakeUserID() string {
|
func generateFakeUserID() string {
|
||||||
hexBytes := make([]byte, 32)
|
hexBytes := make([]byte, 32)
|
||||||
_, _ = rand.Read(hexBytes)
|
_, _ = rand.Read(hexBytes)
|
||||||
hexPart := hex.EncodeToString(hexBytes)
|
hexPart := hex.EncodeToString(hexBytes)
|
||||||
uuidPart := uuid.New().String()
|
accountUUID := uuid.New().String()
|
||||||
return "user_" + hexPart + "_account__session_" + uuidPart
|
sessionUUID := uuid.New().String()
|
||||||
|
return "user_" + hexPart + "_account_" + accountUUID + "_session_" + sessionUUID
|
||||||
}
|
}
|
||||||
|
|
||||||
// isValidUserID checks if a user ID matches Claude Code format.
|
// isValidUserID checks if a user ID matches Claude Code format.
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -358,7 +358,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
@@ -616,6 +616,10 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
if promptCacheKey.Exists() {
|
if promptCacheKey.Exists() {
|
||||||
cache.ID = promptCacheKey.String()
|
cache.ID = promptCacheKey.String()
|
||||||
}
|
}
|
||||||
|
} else if from == "openai" {
|
||||||
|
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
||||||
|
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cache.ID != "" {
|
if cache.ID != "" {
|
||||||
@@ -673,6 +677,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||||
|
err := statusErr{code: statusCode, msg: string(body)}
|
||||||
|
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||||
|
err.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||||
|
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
|
||||||
|
resetAtTime := time.Unix(resetsAt, 0)
|
||||||
|
if resetAtTime.After(now) {
|
||||||
|
retryAfter := resetAtTime.Sub(now)
|
||||||
|
return &retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
|
||||||
|
retryAfter := time.Duration(resetsInSeconds) * time.Second
|
||||||
|
return &retryAfter
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
|
|||||||
64
internal/runtime/executor/codex_executor_cache_test.go
Normal file
64
internal/runtime/executor/codex_executor_cache_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ginCtx.Set("apiKey", "test-api-key")
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
|
executor := &CodexExecutor{}
|
||||||
|
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`)
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.3-codex",
|
||||||
|
Payload: []byte(`{"model":"gpt-5.3-codex"}`),
|
||||||
|
}
|
||||||
|
url := "https://example.com/responses"
|
||||||
|
|
||||||
|
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cacheHelper error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, errRead := io.ReadAll(httpReq.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
t.Fatalf("read request body: %v", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()
|
||||||
|
gotKey := gjson.GetBytes(body, "prompt_cache_key").String()
|
||||||
|
if gotKey != expectedKey {
|
||||||
|
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||||
|
}
|
||||||
|
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
||||||
|
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
||||||
|
}
|
||||||
|
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||||
|
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cacheHelper error (second call): %v", err)
|
||||||
|
}
|
||||||
|
body2, errRead2 := io.ReadAll(httpReq2.Body)
|
||||||
|
if errRead2 != nil {
|
||||||
|
t.Fatalf("read request body (second call): %v", errRead2)
|
||||||
|
}
|
||||||
|
gotKey2 := gjson.GetBytes(body2, "prompt_cache_key").String()
|
||||||
|
if gotKey2 != expectedKey {
|
||||||
|
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseCodexRetryAfter(t *testing.T) {
|
||||||
|
now := time.Unix(1_700_000_000, 0)
|
||||||
|
|
||||||
|
t.Run("resets_in_seconds", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
|
||||||
|
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||||
|
if retryAfter == nil {
|
||||||
|
t.Fatalf("expected retryAfter, got nil")
|
||||||
|
}
|
||||||
|
if *retryAfter != 123*time.Second {
|
||||||
|
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prefers resets_at", func(t *testing.T) {
|
||||||
|
resetAt := now.Add(5 * time.Minute).Unix()
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
|
||||||
|
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||||
|
if retryAfter == nil {
|
||||||
|
t.Fatalf("expected retryAfter, got nil")
|
||||||
|
}
|
||||||
|
if *retryAfter != 5*time.Minute {
|
||||||
|
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fallback when resets_at is past", func(t *testing.T) {
|
||||||
|
resetAt := now.Add(-1 * time.Minute).Unix()
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
|
||||||
|
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||||
|
if retryAfter == nil {
|
||||||
|
t.Fatalf("expected retryAfter, got nil")
|
||||||
|
}
|
||||||
|
if *retryAfter != 77*time.Second {
|
||||||
|
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-429 status code", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
|
||||||
|
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
|
||||||
|
t.Fatalf("expected nil for non-429, got %v", *got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non usage_limit_reached error type", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
|
||||||
|
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
|
||||||
|
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func itoa(v int64) string {
|
||||||
|
return strconv.FormatInt(v, 10)
|
||||||
|
}
|
||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
@@ -81,7 +80,7 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
|
|||||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(req)
|
applyGeminiCLIHeaders(req, "unknown")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +188,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
}
|
}
|
||||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(reqHTTP)
|
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||||
reqHTTP.Header.Set("Accept", "application/json")
|
reqHTTP.Header.Set("Accept", "application/json")
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
@@ -334,7 +333,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(reqHTTP)
|
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||||
reqHTTP.Header.Set("Accept", "text/event-stream")
|
reqHTTP.Header.Set("Accept", "text/event-stream")
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
@@ -515,7 +514,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(reqHTTP)
|
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
||||||
reqHTTP.Header.Set("Accept", "application/json")
|
reqHTTP.Header.Set("Accept", "application/json")
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
@@ -738,21 +737,11 @@ func stringValue(m map[string]any, key string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream.
|
// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream.
|
||||||
func applyGeminiCLIHeaders(r *http.Request) {
|
// User-Agent is always forced to the GeminiCLI format regardless of the client's value,
|
||||||
var ginHeaders http.Header
|
// so that upstream identifies the request as a native GeminiCLI client.
|
||||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
func applyGeminiCLIHeaders(r *http.Request, model string) {
|
||||||
ginHeaders = ginCtx.Request.Header
|
r.Header.Set("User-Agent", misc.GeminiCLIUserAgent(model))
|
||||||
}
|
r.Header.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader)
|
||||||
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "google-api-nodejs-client/9.15.1")
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Goog-Api-Client", "gl-node/22.17.0")
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Client-Metadata", geminiCLIClientMetadata())
|
|
||||||
}
|
|
||||||
|
|
||||||
// geminiCLIClientMetadata returns a compact metadata string required by upstream.
|
|
||||||
func geminiCLIClientMetadata() string {
|
|
||||||
// Keep parity with CLI client defaults
|
|
||||||
return "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cliPreviewFallbackOrder returns preview model candidates for a base model.
|
// cliPreviewFallbackOrder returns preview model candidates for a base model.
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -490,18 +491,46 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
|
|||||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||||
|
|
||||||
initiator := "user"
|
initiator := "user"
|
||||||
if len(body) > 0 {
|
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" {
|
||||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
initiator = "agent"
|
||||||
for _, msg := range messages.Array() {
|
}
|
||||||
role := msg.Get("role").String()
|
r.Header.Set("X-Initiator", initiator)
|
||||||
if role == "assistant" || role == "tool" {
|
}
|
||||||
initiator = "agent"
|
|
||||||
break
|
func detectLastConversationRole(body []byte) string {
|
||||||
}
|
if len(body) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||||
|
arr := messages.Array()
|
||||||
|
for i := len(arr) - 1; i >= 0; i-- {
|
||||||
|
if role := arr[i].Get("role").String(); role != "" {
|
||||||
|
return role
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.Header.Set("X-Initiator", initiator)
|
|
||||||
|
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
|
||||||
|
arr := inputs.Array()
|
||||||
|
for i := len(arr) - 1; i >= 0; i-- {
|
||||||
|
item := arr[i]
|
||||||
|
|
||||||
|
// Most Responses input items carry a top-level role.
|
||||||
|
if role := item.Get("role").String(); role != "" {
|
||||||
|
return role
|
||||||
|
}
|
||||||
|
|
||||||
|
switch item.Get("type").String() {
|
||||||
|
case "function_call", "function_call_arguments":
|
||||||
|
return "assistant"
|
||||||
|
case "function_call_output", "function_call_response", "tool_result":
|
||||||
|
return "tool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectVisionContent checks if the request body contains vision/image content.
|
// detectVisionContent checks if the request body contains vision/image content.
|
||||||
@@ -1236,3 +1265,99 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
func isHTTPSuccess(statusCode int) bool {
|
func isHTTPSuccess(statusCode int) bool {
|
||||||
return statusCode >= 200 && statusCode < 300
|
return statusCode >= 200 && statusCode < 300
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// defaultCopilotContextLength is the default context window for unknown Copilot models.
|
||||||
|
defaultCopilotContextLength = 128000
|
||||||
|
// defaultCopilotMaxCompletionTokens is the default max output tokens for unknown Copilot models.
|
||||||
|
defaultCopilotMaxCompletionTokens = 16384
|
||||||
|
)
|
||||||
|
|
||||||
|
// FetchGitHubCopilotModels dynamically fetches available models from the GitHub Copilot API.
|
||||||
|
// It exchanges the GitHub access token stored in auth.Metadata for a Copilot API token,
|
||||||
|
// then queries the /models endpoint. Falls back to the static registry on any failure.
|
||||||
|
func FetchGitHubCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||||
|
if auth == nil {
|
||||||
|
log.Debug("github-copilot: auth is nil, using static models")
|
||||||
|
return registry.GetGitHubCopilotModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
log.Debug("github-copilot: no access_token in auth metadata, using static models")
|
||||||
|
return registry.GetGitHubCopilotModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
copilotAuth := copilotauth.NewCopilotAuth(cfg)
|
||||||
|
|
||||||
|
entries, err := copilotAuth.ListModelsWithGitHubToken(ctx, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("github-copilot: failed to fetch dynamic models: %v, using static models", err)
|
||||||
|
return registry.GetGitHubCopilotModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) == 0 {
|
||||||
|
log.Debug("github-copilot: API returned no models, using static models")
|
||||||
|
return registry.GetGitHubCopilotModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a lookup from the static definitions so we can enrich dynamic entries
|
||||||
|
// with known context lengths, thinking support, etc.
|
||||||
|
staticMap := make(map[string]*registry.ModelInfo)
|
||||||
|
for _, m := range registry.GetGitHubCopilotModels() {
|
||||||
|
staticMap[m.ID] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
models := make([]*registry.ModelInfo, 0, len(entries))
|
||||||
|
seen := make(map[string]struct{}, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Deduplicate model IDs to avoid incorrect reference counting.
|
||||||
|
if _, dup := seen[entry.ID]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[entry.ID] = struct{}{}
|
||||||
|
|
||||||
|
m := ®istry.ModelInfo{
|
||||||
|
ID: entry.ID,
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.Created > 0 {
|
||||||
|
m.Created = entry.Created
|
||||||
|
}
|
||||||
|
if entry.Name != "" {
|
||||||
|
m.DisplayName = entry.Name
|
||||||
|
} else {
|
||||||
|
m.DisplayName = entry.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge known metadata from the static fallback list
|
||||||
|
if static, ok := staticMap[entry.ID]; ok {
|
||||||
|
if m.DisplayName == entry.ID && static.DisplayName != "" {
|
||||||
|
m.DisplayName = static.DisplayName
|
||||||
|
}
|
||||||
|
m.Description = static.Description
|
||||||
|
m.ContextLength = static.ContextLength
|
||||||
|
m.MaxCompletionTokens = static.MaxCompletionTokens
|
||||||
|
m.SupportedEndpoints = static.SupportedEndpoints
|
||||||
|
m.Thinking = static.Thinking
|
||||||
|
} else {
|
||||||
|
// Sensible defaults for models not in the static list
|
||||||
|
m.Description = entry.ID + " via GitHub Copilot"
|
||||||
|
m.ContextLength = defaultCopilotContextLength
|
||||||
|
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
models = append(models, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("github-copilot: fetched %d models from API", len(models))
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|||||||
@@ -262,15 +262,15 @@ func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
|
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
e := &GitHubCopilotExecutor{}
|
e := &GitHubCopilotExecutor{}
|
||||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
// Claude Code typical flow: last message is user (tool result), but has assistant in history
|
// Last role governs the initiator decision.
|
||||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||||
e.applyHeaders(req, "token", body)
|
e.applyHeaders(req, "token", body)
|
||||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||||
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
|
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,6 +285,39 @@ func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Hi"}]},{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello"}]}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want agent (last role is assistant)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Use tool"}]},{"type":"function_call","call_id":"c1","name":"Read","arguments":"{}"},{"type":"function_call_output","call_id":"c1","output":"ok"}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want agent (last item maps to tool role)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Tests for x-github-api-version header (Problem M) ---
|
// --- Tests for x-github-api-version header (Problem M) ---
|
||||||
|
|
||||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||||
|
|||||||
@@ -49,15 +49,8 @@ const (
|
|||||||
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
||||||
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
||||||
|
|
||||||
// kiroUserAgent matches Amazon Q CLI style for User-Agent header
|
// kiroIDEAgentMode is the agent mode header value for Kiro IDE requests
|
||||||
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
kiroIDEAgentMode = "vibe"
|
||||||
// kiroFullUserAgent is the complete x-amz-user-agent header (Amazon Q CLI style)
|
|
||||||
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
|
|
||||||
|
|
||||||
// Kiro IDE style headers for IDC auth
|
|
||||||
kiroIDEUserAgent = "aws-sdk-js/1.0.27 ua/2.1 os/win32#10.0.19044 lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E"
|
|
||||||
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.27"
|
|
||||||
kiroIDEAgentModeVibe = "vibe"
|
|
||||||
|
|
||||||
// Socket retry configuration constants
|
// Socket retry configuration constants
|
||||||
// Maximum number of retry attempts for socket/network errors
|
// Maximum number of retry attempts for socket/network errors
|
||||||
@@ -87,20 +80,13 @@ var (
|
|||||||
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
|
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
|
||||||
)
|
)
|
||||||
|
|
||||||
// Global FingerprintManager for dynamic User-Agent generation per token
|
// endpointAliases maps user preference values to canonical endpoint names.
|
||||||
// Each token gets a unique fingerprint on first use, which is cached for subsequent requests
|
var endpointAliases = map[string]string{
|
||||||
var (
|
"codewhisperer": "codewhisperer",
|
||||||
globalFingerprintManager *kiroauth.FingerprintManager
|
"ide": "codewhisperer",
|
||||||
globalFingerprintManagerOnce sync.Once
|
"amazonq": "amazonq",
|
||||||
)
|
"q": "amazonq",
|
||||||
|
"cli": "amazonq",
|
||||||
// getGlobalFingerprintManager returns the global FingerprintManager instance
|
|
||||||
func getGlobalFingerprintManager() *kiroauth.FingerprintManager {
|
|
||||||
globalFingerprintManagerOnce.Do(func() {
|
|
||||||
globalFingerprintManager = kiroauth.NewFingerprintManager()
|
|
||||||
log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation")
|
|
||||||
})
|
|
||||||
return globalFingerprintManager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// retryConfig holds configuration for socket retry logic.
|
// retryConfig holds configuration for socket retry logic.
|
||||||
@@ -433,87 +419,41 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
|||||||
return kiroEndpointConfigs
|
return kiroEndpointConfigs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine API region using shared resolution logic
|
|
||||||
region := resolveKiroAPIRegion(auth)
|
region := resolveKiroAPIRegion(auth)
|
||||||
|
log.Debugf("kiro: using region %s", region)
|
||||||
|
|
||||||
// Build endpoint configs for the specified region
|
configs := buildKiroEndpointConfigs(region)
|
||||||
endpointConfigs := buildKiroEndpointConfigs(region)
|
|
||||||
|
|
||||||
// For IDC auth, use Q endpoint with AI_EDITOR origin
|
|
||||||
// IDC tokens work with Q endpoint using Bearer auth
|
|
||||||
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
|
|
||||||
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
|
|
||||||
if auth.Metadata != nil {
|
|
||||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
|
||||||
if strings.ToLower(authMethod) == "idc" {
|
|
||||||
log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region)
|
|
||||||
return endpointConfigs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for preference
|
|
||||||
var preference string
|
|
||||||
if auth.Metadata != nil {
|
|
||||||
if p, ok := auth.Metadata["preferred_endpoint"].(string); ok {
|
|
||||||
preference = p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check attributes as fallback (e.g. from HTTP headers)
|
|
||||||
if preference == "" && auth.Attributes != nil {
|
|
||||||
preference = auth.Attributes["preferred_endpoint"]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
preference := getAuthValue(auth, "preferred_endpoint")
|
||||||
if preference == "" {
|
if preference == "" {
|
||||||
return endpointConfigs
|
return configs
|
||||||
}
|
}
|
||||||
|
|
||||||
preference = strings.ToLower(strings.TrimSpace(preference))
|
targetName, ok := endpointAliases[preference]
|
||||||
|
if !ok {
|
||||||
|
return configs
|
||||||
|
}
|
||||||
|
|
||||||
// Create new slice to avoid modifying global state
|
var preferred, others []kiroEndpointConfig
|
||||||
var sorted []kiroEndpointConfig
|
for _, cfg := range configs {
|
||||||
var remaining []kiroEndpointConfig
|
if strings.ToLower(cfg.Name) == targetName {
|
||||||
|
preferred = append(preferred, cfg)
|
||||||
for _, cfg := range endpointConfigs {
|
|
||||||
name := strings.ToLower(cfg.Name)
|
|
||||||
// Check for matches
|
|
||||||
// CodeWhisperer aliases: codewhisperer, ide
|
|
||||||
// AmazonQ aliases: amazonq, q, cli
|
|
||||||
isMatch := false
|
|
||||||
if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" {
|
|
||||||
isMatch = true
|
|
||||||
} else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" {
|
|
||||||
isMatch = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if isMatch {
|
|
||||||
sorted = append(sorted, cfg)
|
|
||||||
} else {
|
} else {
|
||||||
remaining = append(remaining, cfg)
|
others = append(others, cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If preference didn't match anything, return default
|
if len(preferred) == 0 {
|
||||||
if len(sorted) == 0 {
|
return configs
|
||||||
return endpointConfigs
|
|
||||||
}
|
}
|
||||||
|
return append(preferred, others...)
|
||||||
// Combine: preferred first, then others
|
|
||||||
return append(sorted, remaining...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API.
|
// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API.
|
||||||
type KiroExecutor struct {
|
type KiroExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
||||||
}
|
profileArnMu sync.Mutex // Serializes profileArn fetches to prevent concurrent map writes
|
||||||
|
|
||||||
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
|
|
||||||
func isIDCAuth(auth *cliproxyauth.Auth) bool {
|
|
||||||
if auth == nil || auth.Metadata == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
|
||||||
return strings.ToLower(authMethod) == "idc"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
||||||
@@ -546,27 +486,22 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor {
|
|||||||
// Identifier returns the unique identifier for this executor.
|
// Identifier returns the unique identifier for this executor.
|
||||||
func (e *KiroExecutor) Identifier() string { return "kiro" }
|
func (e *KiroExecutor) Identifier() string { return "kiro" }
|
||||||
|
|
||||||
// applyDynamicFingerprint applies token-specific fingerprint headers to the request
|
// applyDynamicFingerprint applies account-specific fingerprint headers to the request.
|
||||||
// For IDC auth, uses dynamic fingerprint-based User-Agent
|
|
||||||
// For other auth types, uses static Amazon Q CLI style headers
|
|
||||||
func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
|
func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
|
||||||
if isIDCAuth(auth) {
|
accountKey := getAccountKey(auth)
|
||||||
// Get token-specific fingerprint for dynamic UA generation
|
fp := kiroauth.GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||||
tokenKey := getTokenKey(auth)
|
|
||||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
|
||||||
|
|
||||||
// Use fingerprint-generated dynamic User-Agent
|
req.Header.Set("User-Agent", fp.BuildUserAgent())
|
||||||
req.Header.Set("User-Agent", fp.BuildUserAgent())
|
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
|
||||||
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
|
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
|
||||||
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
|
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||||
|
|
||||||
log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)",
|
keyPrefix := accountKey
|
||||||
tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
|
if len(keyPrefix) > 8 {
|
||||||
} else {
|
keyPrefix = keyPrefix[:8]
|
||||||
// Use static Amazon Q CLI style headers for non-IDC auth
|
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
|
||||||
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
|
||||||
}
|
}
|
||||||
|
log.Debugf("kiro: using dynamic fingerprint for account %s (SDK:%s, OS:%s/%s, Kiro:%s)",
|
||||||
|
keyPrefix+"...", fp.StreamingSDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrepareRequest prepares the HTTP request before execution.
|
// PrepareRequest prepares the HTTP request before execution.
|
||||||
@@ -609,17 +544,51 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTokenKey returns a unique key for rate limiting based on auth credentials.
|
// getAccountKey returns a stable account key for fingerprint lookup and rate limiting.
|
||||||
// Uses auth ID if available, otherwise falls back to a hash of the access token.
|
// Fallback order:
|
||||||
func getTokenKey(auth *cliproxyauth.Auth) string {
|
// 1) client_id / refresh_token (best account identity)
|
||||||
|
// 2) auth.ID (stable local auth record)
|
||||||
|
// 3) profile_arn (stable AWS profile identity)
|
||||||
|
// 4) access_token (least preferred but deterministic)
|
||||||
|
// 5) fixed anonymous seed
|
||||||
|
func getAccountKey(auth *cliproxyauth.Auth) string {
|
||||||
|
var clientID, refreshToken, profileArn string
|
||||||
|
if auth != nil && auth.Metadata != nil {
|
||||||
|
clientID, _ = auth.Metadata["client_id"].(string)
|
||||||
|
refreshToken, _ = auth.Metadata["refresh_token"].(string)
|
||||||
|
profileArn, _ = auth.Metadata["profile_arn"].(string)
|
||||||
|
}
|
||||||
|
if clientID != "" || refreshToken != "" {
|
||||||
|
return kiroauth.GetAccountKey(clientID, refreshToken)
|
||||||
|
}
|
||||||
if auth != nil && auth.ID != "" {
|
if auth != nil && auth.ID != "" {
|
||||||
return auth.ID
|
return kiroauth.GenerateAccountKey(auth.ID)
|
||||||
}
|
}
|
||||||
accessToken, _ := kiroCredentials(auth)
|
if profileArn != "" {
|
||||||
if len(accessToken) > 16 {
|
return kiroauth.GenerateAccountKey(profileArn)
|
||||||
return accessToken[:16]
|
|
||||||
}
|
}
|
||||||
return accessToken
|
if accessToken, _ := kiroCredentials(auth); accessToken != "" {
|
||||||
|
return kiroauth.GenerateAccountKey(accessToken)
|
||||||
|
}
|
||||||
|
return kiroauth.GenerateAccountKey("kiro-anonymous")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAuthValue looks up a value by key in auth Metadata, then Attributes.
|
||||||
|
func getAuthValue(auth *cliproxyauth.Auth, key string) string {
|
||||||
|
if auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if auth.Metadata != nil {
|
||||||
|
if v, ok := auth.Metadata[key].(string); ok && v != "" {
|
||||||
|
return strings.ToLower(strings.TrimSpace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if v := auth.Attributes[key]; v != "" {
|
||||||
|
return strings.ToLower(strings.TrimSpace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute sends the request to Kiro API and returns the response.
|
// Execute sends the request to Kiro API and returns the response.
|
||||||
@@ -631,7 +600,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Rate limiting: get token key for tracking
|
// Rate limiting: get token key for tracking
|
||||||
tokenKey := getTokenKey(auth)
|
tokenKey := getAccountKey(auth)
|
||||||
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
||||||
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
||||||
|
|
||||||
@@ -693,6 +662,13 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
|
|
||||||
kiroModelID := e.mapModelToKiro(req.Model)
|
kiroModelID := e.mapModelToKiro(req.Model)
|
||||||
|
|
||||||
|
// Fetch profileArn if missing (for imported accounts from Kiro IDE)
|
||||||
|
if profileArn == "" {
|
||||||
|
if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" {
|
||||||
|
profileArn = fetched
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Determine agentic mode and effective profile ARN using helper functions
|
// Determine agentic mode and effective profile ARN using helper functions
|
||||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||||
@@ -749,7 +725,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||||
}
|
}
|
||||||
// Kiro-specific headers
|
// Kiro-specific headers
|
||||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
|
||||||
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
|
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||||
|
|
||||||
// Apply dynamic fingerprint-based headers
|
// Apply dynamic fingerprint-based headers
|
||||||
@@ -1060,7 +1036,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Rate limiting: get token key for tracking
|
// Rate limiting: get token key for tracking
|
||||||
tokenKey := getTokenKey(auth)
|
tokenKey := getAccountKey(auth)
|
||||||
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
||||||
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
||||||
|
|
||||||
@@ -1126,6 +1102,13 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
|
|
||||||
kiroModelID := e.mapModelToKiro(req.Model)
|
kiroModelID := e.mapModelToKiro(req.Model)
|
||||||
|
|
||||||
|
// Fetch profileArn if missing (for imported accounts from Kiro IDE)
|
||||||
|
if profileArn == "" {
|
||||||
|
if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" {
|
||||||
|
profileArn = fetched
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Determine agentic mode and effective profile ARN using helper functions
|
// Determine agentic mode and effective profile ARN using helper functions
|
||||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||||
@@ -1185,7 +1168,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||||
}
|
}
|
||||||
// Kiro-specific headers
|
// Kiro-specific headers
|
||||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
|
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
|
||||||
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
|
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||||
|
|
||||||
// Apply dynamic fingerprint-based headers
|
// Apply dynamic fingerprint-based headers
|
||||||
@@ -1647,62 +1630,23 @@ func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
|
|||||||
return isAgentic, isChatOnly
|
return isAgentic, isChatOnly
|
||||||
}
|
}
|
||||||
|
|
||||||
// getEffectiveProfileArn determines if profileArn should be included based on auth method.
|
// getEffectiveProfileArnWithWarning suppresses profileArn for builder-id and AWS SSO OIDC auth.
|
||||||
// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC).
|
// Builder-id users (auth_method == "builder-id") and AWS SSO OIDC users (auth_type == "aws_sso_oidc")
|
||||||
//
|
// don't need profileArn — sending it causes 403 errors.
|
||||||
// Detection logic (matching kiro-openai-gateway):
|
// For all other auth methods (e.g. social auth), profileArn is returned as-is,
|
||||||
// 1. Check auth_method field: "builder-id" or "idc"
|
// with a warning logged if it is empty.
|
||||||
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
|
|
||||||
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
|
|
||||||
func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
|
||||||
if auth != nil && auth.Metadata != nil {
|
|
||||||
// Check 1: auth_method field (from CLIProxyAPI tokens)
|
|
||||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
|
||||||
return "" // AWS SSO OIDC - don't include profileArn
|
|
||||||
}
|
|
||||||
// Check 2: auth_type field (from kiro-cli tokens)
|
|
||||||
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
|
|
||||||
return "" // AWS SSO OIDC - don't include profileArn
|
|
||||||
}
|
|
||||||
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature)
|
|
||||||
_, hasClientID := auth.Metadata["client_id"].(string)
|
|
||||||
_, hasClientSecret := auth.Metadata["client_secret"].(string)
|
|
||||||
if hasClientID && hasClientSecret {
|
|
||||||
return "" // AWS SSO OIDC - don't include profileArn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return profileArn
|
|
||||||
}
|
|
||||||
|
|
||||||
// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
|
|
||||||
// and logs a warning if profileArn is missing for non-builder-id auth.
|
|
||||||
// This consolidates the auth_method check that was previously done separately.
|
|
||||||
//
|
|
||||||
// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors.
|
|
||||||
// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn.
|
|
||||||
//
|
|
||||||
// Detection logic (matching kiro-openai-gateway):
|
|
||||||
// 1. Check auth_method field: "builder-id" or "idc"
|
|
||||||
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
|
|
||||||
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
|
|
||||||
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
// Check 1: auth_method field (from CLIProxyAPI tokens)
|
// Check 1: auth_method field, skip for builder-id only
|
||||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
|
||||||
return "" // AWS SSO OIDC - don't include profileArn
|
return ""
|
||||||
}
|
}
|
||||||
// Check 2: auth_type field (from kiro-cli tokens)
|
// Check 2: auth_type field (from kiro-cli tokens)
|
||||||
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
|
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
|
||||||
return "" // AWS SSO OIDC - don't include profileArn
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
}
|
}
|
||||||
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway)
|
|
||||||
_, hasClientID := auth.Metadata["client_id"].(string)
|
|
||||||
_, hasClientSecret := auth.Metadata["client_secret"].(string)
|
|
||||||
if hasClientID && hasClientSecret {
|
|
||||||
return "" // AWS SSO OIDC - don't include profileArn
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// For social auth (Kiro Desktop), profileArn is required
|
// For social auth and IDC, profileArn is required
|
||||||
if profileArn == "" {
|
if profileArn == "" {
|
||||||
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
||||||
}
|
}
|
||||||
@@ -2514,7 +2458,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
|
reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
|
||||||
var totalUsage usage.Detail
|
var totalUsage usage.Detail
|
||||||
var hasToolUses bool // Track if any tool uses were emitted
|
var hasToolUses bool // Track if any tool uses were emitted
|
||||||
var hasTruncatedTools bool // Track if any tool uses were truncated
|
|
||||||
var upstreamStopReason string // Track stop_reason from upstream events
|
var upstreamStopReason string // Track stop_reason from upstream events
|
||||||
|
|
||||||
// Tool use state tracking for input buffering and deduplication
|
// Tool use state tracking for input buffering and deduplication
|
||||||
@@ -3342,59 +3285,9 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
|
|
||||||
// Emit completed tool uses
|
// Emit completed tool uses
|
||||||
for _, tu := range completedToolUses {
|
for _, tu := range completedToolUses {
|
||||||
// Check if this tool was truncated - emit with SOFT_LIMIT_REACHED marker
|
// Skip truncated tools - don't emit fake marker tool_use
|
||||||
if tu.IsTruncated {
|
if tu.IsTruncated {
|
||||||
hasTruncatedTools = true
|
log.Warnf("kiro: streamToChannel skipping truncated tool: %s (ID: %s)", tu.Name, tu.ToolUseID)
|
||||||
log.Infof("kiro: streamToChannel emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", tu.Name, tu.ToolUseID)
|
|
||||||
|
|
||||||
// Close text block if open
|
|
||||||
if isTextBlockOpen && contentBlockIndex >= 0 {
|
|
||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
|
||||||
for _, chunk := range sseData {
|
|
||||||
if chunk != "" {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
isTextBlockOpen = false
|
|
||||||
}
|
|
||||||
|
|
||||||
contentBlockIndex++
|
|
||||||
|
|
||||||
// Emit tool_use with SOFT_LIMIT_REACHED marker input
|
|
||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
|
||||||
for _, chunk := range sseData {
|
|
||||||
if chunk != "" {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build SOFT_LIMIT_REACHED marker input
|
|
||||||
markerInput := map[string]interface{}{
|
|
||||||
"_status": "SOFT_LIMIT_REACHED",
|
|
||||||
"_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
|
|
||||||
}
|
|
||||||
|
|
||||||
markerJSON, _ := json.Marshal(markerInput)
|
|
||||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(markerJSON), contentBlockIndex)
|
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
|
||||||
for _, chunk := range sseData {
|
|
||||||
if chunk != "" {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close tool_use block
|
|
||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
|
||||||
for _, chunk := range sseData {
|
|
||||||
if chunk != "" {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hasToolUses = true // Keep this so stop_reason = tool_use
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3696,12 +3589,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
|
// Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
|
||||||
// SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
|
|
||||||
stopReason := upstreamStopReason
|
stopReason := upstreamStopReason
|
||||||
if hasTruncatedTools {
|
|
||||||
// Log that we're using SOFT_LIMIT_REACHED approach
|
|
||||||
log.Infof("kiro: streamToChannel using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use for truncated tools")
|
|
||||||
}
|
|
||||||
if stopReason == "" {
|
if stopReason == "" {
|
||||||
if hasToolUses {
|
if hasToolUses {
|
||||||
stopReason = "tool_use"
|
stopReason = "tool_use"
|
||||||
@@ -3999,6 +3887,51 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fetchAndSaveProfileArn fetches profileArn from API if missing, updates auth and persists to file.
|
||||||
|
func (e *KiroExecutor) fetchAndSaveProfileArn(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) string {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip for Builder ID - they don't have profiles
|
||||||
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
|
||||||
|
log.Debugf("kiro executor: skipping profileArn fetch for builder-id auth")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
e.profileArnMu.Lock()
|
||||||
|
defer e.profileArnMu.Unlock()
|
||||||
|
|
||||||
|
// Double-check: another goroutine may have already fetched and saved the profileArn
|
||||||
|
if arn, ok := auth.Metadata["profile_arn"].(string); ok && arn != "" {
|
||||||
|
return arn
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID, _ := auth.Metadata["client_id"].(string)
|
||||||
|
refreshToken, _ := auth.Metadata["refresh_token"].(string)
|
||||||
|
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
|
||||||
|
profileArn := ssoClient.FetchProfileArn(ctx, accessToken, clientID, refreshToken)
|
||||||
|
if profileArn == "" {
|
||||||
|
log.Debugf("kiro executor: FetchProfileArn returned no profiles")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.Metadata["profile_arn"] = profileArn
|
||||||
|
if auth.Attributes == nil {
|
||||||
|
auth.Attributes = make(map[string]string)
|
||||||
|
}
|
||||||
|
auth.Attributes["profile_arn"] = profileArn
|
||||||
|
|
||||||
|
if err := e.persistRefreshedAuth(auth); err != nil {
|
||||||
|
log.Warnf("kiro executor: failed to persist profileArn: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Infof("kiro executor: fetched and saved profileArn: %s", profileArn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return profileArn
|
||||||
|
}
|
||||||
|
|
||||||
// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
|
// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
|
||||||
// 当内存中的 token 已过期时,尝试从文件读取最新的 token
|
// 当内存中的 token 已过期时,尝试从文件读取最新的 token
|
||||||
// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
|
// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
|
||||||
@@ -4728,7 +4661,7 @@ func (e *KiroExecutor) callKiroAndBuffer(
|
|||||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||||
|
|
||||||
tokenKey := getTokenKey(auth)
|
tokenKey := getAccountKey(auth)
|
||||||
|
|
||||||
kiroStream, err := e.executeStreamWithRetry(
|
kiroStream, err := e.executeStreamWithRetry(
|
||||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||||
@@ -4770,7 +4703,7 @@ func (e *KiroExecutor) callKiroDirectStream(
|
|||||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||||
|
|
||||||
tokenKey := getTokenKey(auth)
|
tokenKey := getAccountKey(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
var streamErr error
|
var streamErr error
|
||||||
@@ -4819,7 +4752,7 @@ func (e *KiroExecutor) executeNonStreamFallback(
|
|||||||
kiroModelID := e.mapModelToKiro(req.Model)
|
kiroModelID := e.mapModelToKiro(req.Model)
|
||||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||||
tokenKey := getTokenKey(auth)
|
tokenKey := getAccountKey(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
423
internal/runtime/executor/kiro_executor_test.go
Normal file
423
internal/runtime/executor/kiro_executor_test.go
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildKiroEndpointConfigs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
region string
|
||||||
|
expectedURL string
|
||||||
|
expectedOrigin string
|
||||||
|
expectedName string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty region - defaults to us-east-1",
|
||||||
|
region: "",
|
||||||
|
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||||
|
expectedOrigin: "AI_EDITOR",
|
||||||
|
expectedName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "us-east-1",
|
||||||
|
region: "us-east-1",
|
||||||
|
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||||
|
expectedOrigin: "AI_EDITOR",
|
||||||
|
expectedName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ap-southeast-1",
|
||||||
|
region: "ap-southeast-1",
|
||||||
|
expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse",
|
||||||
|
expectedOrigin: "AI_EDITOR",
|
||||||
|
expectedName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "eu-west-1",
|
||||||
|
region: "eu-west-1",
|
||||||
|
expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse",
|
||||||
|
expectedOrigin: "AI_EDITOR",
|
||||||
|
expectedName: "AmazonQ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
configs := buildKiroEndpointConfigs(tt.region)
|
||||||
|
|
||||||
|
if len(configs) != 2 {
|
||||||
|
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check primary endpoint (AmazonQ)
|
||||||
|
primary := configs[0]
|
||||||
|
if primary.URL != tt.expectedURL {
|
||||||
|
t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL)
|
||||||
|
}
|
||||||
|
if primary.Origin != tt.expectedOrigin {
|
||||||
|
t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin)
|
||||||
|
}
|
||||||
|
if primary.Name != tt.expectedName {
|
||||||
|
t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName)
|
||||||
|
}
|
||||||
|
if primary.AmzTarget != "" {
|
||||||
|
t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check fallback endpoint (CodeWhisperer)
|
||||||
|
fallback := configs[1]
|
||||||
|
if fallback.Name != "CodeWhisperer" {
|
||||||
|
t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer")
|
||||||
|
}
|
||||||
|
// CodeWhisperer fallback uses the same region as Q endpoint
|
||||||
|
expectedRegion := tt.region
|
||||||
|
if expectedRegion == "" {
|
||||||
|
expectedRegion = kiroDefaultRegion
|
||||||
|
}
|
||||||
|
expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion)
|
||||||
|
if fallback.URL != expectedFallbackURL {
|
||||||
|
t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL)
|
||||||
|
}
|
||||||
|
if fallback.AmzTarget == "" {
|
||||||
|
t.Error("fallback AmzTarget should NOT be empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) {
|
||||||
|
configs := getKiroEndpointConfigs(nil)
|
||||||
|
|
||||||
|
if len(configs) != 2 {
|
||||||
|
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return default us-east-1 configs
|
||||||
|
if configs[0].Name != "AmazonQ" {
|
||||||
|
t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ")
|
||||||
|
}
|
||||||
|
expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
|
||||||
|
if configs[0].URL != expectedURL {
|
||||||
|
t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) {
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := getKiroEndpointConfigs(auth)
|
||||||
|
|
||||||
|
if len(configs) != 2 {
|
||||||
|
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse"
|
||||||
|
if configs[0].URL != expectedURL {
|
||||||
|
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) {
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"api_region": "eu-central-1",
|
||||||
|
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := getKiroEndpointConfigs(auth)
|
||||||
|
|
||||||
|
// api_region should take precedence over profile_arn
|
||||||
|
expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse"
|
||||||
|
if configs[0].URL != expectedURL {
|
||||||
|
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
preference string
|
||||||
|
expectedFirstName string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Prefer codewhisperer",
|
||||||
|
preference: "codewhisperer",
|
||||||
|
expectedFirstName: "CodeWhisperer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Prefer ide (alias for codewhisperer)",
|
||||||
|
preference: "ide",
|
||||||
|
expectedFirstName: "CodeWhisperer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Prefer amazonq",
|
||||||
|
preference: "amazonq",
|
||||||
|
expectedFirstName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Prefer q (alias for amazonq)",
|
||||||
|
preference: "q",
|
||||||
|
expectedFirstName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Prefer cli (alias for amazonq)",
|
||||||
|
preference: "cli",
|
||||||
|
expectedFirstName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unknown preference - no reordering",
|
||||||
|
preference: "unknown",
|
||||||
|
expectedFirstName: "AmazonQ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty preference - no reordering",
|
||||||
|
preference: "",
|
||||||
|
expectedFirstName: "AmazonQ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"preferred_endpoint": tt.preference,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := getKiroEndpointConfigs(auth)
|
||||||
|
|
||||||
|
if configs[0].Name != tt.expectedFirstName {
|
||||||
|
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) {
|
||||||
|
// Test that preferred_endpoint can also come from Attributes
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{},
|
||||||
|
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := getKiroEndpointConfigs(auth)
|
||||||
|
|
||||||
|
if configs[0].Name != "CodeWhisperer" {
|
||||||
|
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) {
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"preferred_endpoint": "amazonq"},
|
||||||
|
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := getKiroEndpointConfigs(auth)
|
||||||
|
|
||||||
|
// Metadata should take precedence
|
||||||
|
if configs[0].Name != "AmazonQ" {
|
||||||
|
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAuthValue(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
auth *cliproxyauth.Auth
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "From metadata",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"test_key": "metadata_value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "metadata_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "From attributes (fallback)",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "attribute_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Metadata takes precedence",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"test_key": "metadata_value"},
|
||||||
|
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "metadata_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key not found",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"other_key": "value"},
|
||||||
|
Attributes: map[string]string{"another_key": "value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil metadata",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "attribute_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Both nil",
|
||||||
|
auth: &cliproxyauth.Auth{},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Value is trimmed and lowercased",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"test_key": " UPPER_VALUE "},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "upper_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty string value in metadata - falls back to attributes",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"test_key": ""},
|
||||||
|
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "attribute_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-string value in metadata - falls back to attributes",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{"test_key": 123},
|
||||||
|
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||||
|
},
|
||||||
|
key: "test_key",
|
||||||
|
expected: "attribute_value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := getAuthValue(tt.auth, tt.key)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("getAuthValue() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
auth *cliproxyauth.Auth
|
||||||
|
checkFn func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "From client_id",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"client_id": "test-client-id-123",
|
||||||
|
"refresh_token": "test-refresh-token-456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
checkFn: func(t *testing.T, result string) {
|
||||||
|
expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456")
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "From refresh_token only",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"refresh_token": "test-refresh-token-789",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
checkFn: func(t *testing.T, result string) {
|
||||||
|
expected := kiroauth.GetAccountKey("", "test-refresh-token-789")
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil auth",
|
||||||
|
auth: nil,
|
||||||
|
checkFn: func(t *testing.T, result string) {
|
||||||
|
if len(result) != 16 {
|
||||||
|
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil metadata",
|
||||||
|
auth: &cliproxyauth.Auth{},
|
||||||
|
checkFn: func(t *testing.T, result string) {
|
||||||
|
if len(result) != 16 {
|
||||||
|
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty metadata",
|
||||||
|
auth: &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{},
|
||||||
|
},
|
||||||
|
checkFn: func(t *testing.T, result string) {
|
||||||
|
if len(result) != 16 {
|
||||||
|
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := getAccountKey(tt.auth)
|
||||||
|
tt.checkFn(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndpointAliases(t *testing.T) {
|
||||||
|
// Verify all expected aliases are defined
|
||||||
|
expectedAliases := map[string]string{
|
||||||
|
"codewhisperer": "codewhisperer",
|
||||||
|
"ide": "codewhisperer",
|
||||||
|
"amazonq": "amazonq",
|
||||||
|
"q": "amazonq",
|
||||||
|
"cli": "amazonq",
|
||||||
|
}
|
||||||
|
|
||||||
|
for alias, target := range expectedAliases {
|
||||||
|
if actual, ok := endpointAliases[alias]; !ok {
|
||||||
|
t.Errorf("missing alias %q", alias)
|
||||||
|
} else if actual != target {
|
||||||
|
t.Errorf("alias %q = %q, want %q", alias, actual, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected aliases
|
||||||
|
if len(endpointAliases) != len(expectedAliases) {
|
||||||
|
t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
@@ -22,9 +23,151 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||||
|
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||||
|
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||||
|
var qwenBeijingLoc = func() *time.Location {
|
||||||
|
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
if err != nil || loc == nil {
|
||||||
|
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||||
|
return time.FixedZone("CST", 8*3600)
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
}()
|
||||||
|
|
||||||
|
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||||
|
var qwenQuotaCodes = map[string]struct{}{
|
||||||
|
"insufficient_quota": {},
|
||||||
|
"quota_exceeded": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
|
||||||
|
// Qwen has a limit of 60 requests per minute per account.
|
||||||
|
var qwenRateLimiter = struct {
|
||||||
|
sync.Mutex
|
||||||
|
requests map[string][]time.Time // authID -> request timestamps
|
||||||
|
}{
|
||||||
|
requests: make(map[string][]time.Time),
|
||||||
|
}
|
||||||
|
|
||||||
|
// redactAuthID returns a redacted version of the auth ID for safe logging.
|
||||||
|
// Keeps a small prefix/suffix to allow correlation across events.
|
||||||
|
func redactAuthID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(id) <= 8 {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return id[:4] + "..." + id[len(id)-4:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||||
|
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||||
|
func checkQwenRateLimit(authID string) error {
|
||||||
|
if authID == "" {
|
||||||
|
// Empty authID should not bypass rate limiting in production
|
||||||
|
// Use debug level to avoid log spam for certain auth flows
|
||||||
|
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := now.Add(-qwenRateLimitWindow)
|
||||||
|
|
||||||
|
qwenRateLimiter.Lock()
|
||||||
|
defer qwenRateLimiter.Unlock()
|
||||||
|
|
||||||
|
// Get and filter timestamps within the window
|
||||||
|
timestamps := qwenRateLimiter.requests[authID]
|
||||||
|
var validTimestamps []time.Time
|
||||||
|
for _, ts := range timestamps {
|
||||||
|
if ts.After(windowStart) {
|
||||||
|
validTimestamps = append(validTimestamps, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always prune expired entries to prevent memory leak
|
||||||
|
// Delete empty entries, otherwise update with pruned slice
|
||||||
|
if len(validTimestamps) == 0 {
|
||||||
|
delete(qwenRateLimiter.requests, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if rate limit exceeded
|
||||||
|
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||||
|
// Calculate when the oldest request will expire
|
||||||
|
oldestInWindow := validTimestamps[0]
|
||||||
|
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||||
|
if retryAfter < time.Second {
|
||||||
|
retryAfter = time.Second
|
||||||
|
}
|
||||||
|
retryAfterSec := int(retryAfter.Seconds())
|
||||||
|
return statusErr{
|
||||||
|
code: http.StatusTooManyRequests,
|
||||||
|
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||||
|
retryAfter: &retryAfter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record this request and update the map with pruned timestamps
|
||||||
|
validTimestamps = append(validTimestamps, now)
|
||||||
|
qwenRateLimiter.requests[authID] = validTimestamps
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||||
|
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||||
|
func isQwenQuotaError(body []byte) bool {
|
||||||
|
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||||
|
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||||
|
|
||||||
|
// Primary check: exact match on error.code or error.type (most reliable)
|
||||||
|
if _, ok := qwenQuotaCodes[code]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check message only if code/type don't match (less reliable)
|
||||||
|
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||||
|
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||||
|
strings.Contains(msg, "free allocated quota exceeded") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||||
|
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||||
|
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||||
|
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||||
|
errCode = httpCode
|
||||||
|
// Only check quota errors for expected status codes to avoid false positives
|
||||||
|
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||||
|
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||||
|
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||||
|
cooldown := timeUntilNextDay()
|
||||||
|
retryAfter = &cooldown
|
||||||
|
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||||
|
}
|
||||||
|
return errCode, retryAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||||
|
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||||
|
func timeUntilNextDay() time.Duration {
|
||||||
|
now := time.Now()
|
||||||
|
nowLocal := now.In(qwenBeijingLoc)
|
||||||
|
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||||
|
return tomorrow.Sub(now)
|
||||||
|
}
|
||||||
|
|
||||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||||
type QwenExecutor struct {
|
type QwenExecutor struct {
|
||||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, false)
|
applyQwenHeaders(httpReq, token, false)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, true)
|
applyQwenHeaders(httpReq, token, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
|
|||||||
if config.Mode != ModeLevel {
|
if config.Mode != ModeLevel {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
if !isBudgetBasedProvider(toFormat) || !isLevelBasedProvider(fromFormat) {
|
if !isBudgetCapableProvider(toFormat) {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
budget, ok := ConvertLevelToBudget(string(config.Level))
|
budget, ok := ConvertLevelToBudget(string(config.Level))
|
||||||
@@ -353,6 +353,26 @@ func extractClaudeConfig(body []byte) ThinkingConfig {
|
|||||||
if thinkingType == "disabled" {
|
if thinkingType == "disabled" {
|
||||||
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
}
|
}
|
||||||
|
if thinkingType == "adaptive" || thinkingType == "auto" {
|
||||||
|
// Claude adaptive thinking uses output_config.effort (low/medium/high/max).
|
||||||
|
// We only treat it as a thinking config when effort is explicitly present;
|
||||||
|
// otherwise we passthrough and let upstream defaults apply.
|
||||||
|
if effort := gjson.GetBytes(body, "output_config.effort"); effort.Exists() && effort.Type == gjson.String {
|
||||||
|
value := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||||
|
if value == "" {
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
switch value {
|
||||||
|
case "none":
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
case "auto":
|
||||||
|
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||||
|
default:
|
||||||
|
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
// Check budget_tokens
|
// Check budget_tokens
|
||||||
if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() {
|
if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() {
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ var levelToBudgetMap = map[string]int{
|
|||||||
"medium": 8192,
|
"medium": 8192,
|
||||||
"high": 24576,
|
"high": 24576,
|
||||||
"xhigh": 32768,
|
"xhigh": 32768,
|
||||||
|
// "max" is used by Claude adaptive thinking effort. We map it to a large budget
|
||||||
|
// and rely on per-model clamping when converting to budget-only providers.
|
||||||
|
"max": 128000,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertLevelToBudget converts a thinking level to a budget value.
|
// ConvertLevelToBudget converts a thinking level to a budget value.
|
||||||
@@ -31,6 +34,7 @@ var levelToBudgetMap = map[string]int{
|
|||||||
// - medium → 8192
|
// - medium → 8192
|
||||||
// - high → 24576
|
// - high → 24576
|
||||||
// - xhigh → 32768
|
// - xhigh → 32768
|
||||||
|
// - max → 128000
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - budget: The converted budget value
|
// - budget: The converted budget value
|
||||||
@@ -92,6 +96,43 @@ func ConvertBudgetToLevel(budget int) (string, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasLevel reports whether the given target level exists in the levels slice.
|
||||||
|
// Matching is case-insensitive with leading/trailing whitespace trimmed.
|
||||||
|
func HasLevel(levels []string, target string) bool {
|
||||||
|
for _, level := range levels {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(level), target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapToClaudeEffort maps a generic thinking level string to a Claude adaptive
|
||||||
|
// thinking effort value (low/medium/high/max).
|
||||||
|
//
|
||||||
|
// supportsMax indicates whether the target model supports "max" effort.
|
||||||
|
// Returns the mapped effort and true if the level is valid, or ("", false) otherwise.
|
||||||
|
func MapToClaudeEffort(level string, supportsMax bool) (string, bool) {
|
||||||
|
level = strings.ToLower(strings.TrimSpace(level))
|
||||||
|
switch level {
|
||||||
|
case "":
|
||||||
|
return "", false
|
||||||
|
case "minimal":
|
||||||
|
return "low", true
|
||||||
|
case "low", "medium", "high":
|
||||||
|
return level, true
|
||||||
|
case "xhigh", "max":
|
||||||
|
if supportsMax {
|
||||||
|
return "max", true
|
||||||
|
}
|
||||||
|
return "high", true
|
||||||
|
case "auto":
|
||||||
|
return "high", true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ModelCapability describes the thinking format support of a model.
|
// ModelCapability describes the thinking format support of a model.
|
||||||
type ModelCapability int
|
type ModelCapability int
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
// Package claude implements thinking configuration scaffolding for Claude models.
|
// Package claude implements thinking configuration scaffolding for Claude models.
|
||||||
//
|
//
|
||||||
// Claude models use the thinking.budget_tokens format with values in the range
|
// Claude models support two thinking control styles:
|
||||||
// 1024-128000. Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5),
|
// - Manual thinking: thinking.type="enabled" with thinking.budget_tokens (token budget)
|
||||||
// while older models do not.
|
// - Adaptive thinking (Claude 4.6): thinking.type="adaptive" with output_config.effort (low/medium/high/max)
|
||||||
|
//
|
||||||
|
// Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5), while older models do not.
|
||||||
// See: _bmad-output/planning-artifacts/architecture.md#Epic-6
|
// See: _bmad-output/planning-artifacts/architecture.md#Epic-6
|
||||||
package claude
|
package claude
|
||||||
|
|
||||||
@@ -34,7 +36,11 @@ func init() {
|
|||||||
// - Budget clamping to model range
|
// - Budget clamping to model range
|
||||||
// - ZeroAllowed constraint enforcement
|
// - ZeroAllowed constraint enforcement
|
||||||
//
|
//
|
||||||
// Apply only processes ModeBudget and ModeNone; other modes are passed through unchanged.
|
// Apply processes:
|
||||||
|
// - ModeBudget: manual thinking budget_tokens
|
||||||
|
// - ModeLevel: adaptive thinking effort (Claude 4.6)
|
||||||
|
// - ModeAuto: provider default adaptive/manual behavior
|
||||||
|
// - ModeNone: disabled
|
||||||
//
|
//
|
||||||
// Expected output format when enabled:
|
// Expected output format when enabled:
|
||||||
//
|
//
|
||||||
@@ -45,6 +51,17 @@ func init() {
|
|||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
|
// Expected output format for adaptive:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "thinking": {
|
||||||
|
// "type": "adaptive"
|
||||||
|
// },
|
||||||
|
// "output_config": {
|
||||||
|
// "effort": "high"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
// Expected output format when disabled:
|
// Expected output format when disabled:
|
||||||
//
|
//
|
||||||
// {
|
// {
|
||||||
@@ -60,30 +77,91 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only process ModeBudget and ModeNone; other modes pass through
|
|
||||||
// (caller should use ValidateConfig first to normalize modes)
|
|
||||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone {
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
body = []byte(`{}`)
|
body = []byte(`{}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced)
|
supportsAdaptive := modelInfo != nil && modelInfo.Thinking != nil && len(modelInfo.Thinking.Levels) > 0
|
||||||
// Decide enabled/disabled based on budget value
|
|
||||||
if config.Budget == 0 {
|
switch config.Mode {
|
||||||
|
case thinking.ModeNone:
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
// Adaptive thinking effort is only valid when the model advertises discrete levels.
|
||||||
|
// (Claude 4.6 uses output_config.effort.)
|
||||||
|
if supportsAdaptive && config.Level != "" {
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level))
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback for non-adaptive Claude models: convert level to budget_tokens.
|
||||||
|
if budget, ok := thinking.ConvertLevelToBudget(string(config.Level)); ok {
|
||||||
|
config.Mode = thinking.ModeBudget
|
||||||
|
config.Budget = budget
|
||||||
|
config.Level = ""
|
||||||
|
} else {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
case thinking.ModeBudget:
|
||||||
|
// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced).
|
||||||
|
// Decide enabled/disabled based on budget value.
|
||||||
|
if config.Budget == 0 {
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
|
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint).
|
||||||
|
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
// For Claude 4.6 models, auto maps to adaptive thinking with upstream defaults.
|
||||||
|
if supportsAdaptive {
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
// Explicit effort is optional for adaptive thinking; omit it to allow upstream default.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy fallback: enable thinking without specifying budget_tokens.
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
|
||||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
|
||||||
|
|
||||||
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
|
||||||
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
||||||
@@ -141,7 +219,7 @@ func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto && config.Mode != thinking.ModeLevel {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,14 +231,36 @@ func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte,
|
|||||||
case thinking.ModeNone:
|
case thinking.ModeNone:
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
case thinking.ModeAuto:
|
case thinking.ModeAuto:
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
// For user-defined models, interpret ModeLevel as Claude adaptive thinking effort.
|
||||||
|
// Upstream is responsible for validating whether the target model supports it.
|
||||||
|
if config.Level == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level))
|
||||||
return result, nil
|
return result, nil
|
||||||
default:
|
default:
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config.effort")
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,6 @@
|
|||||||
package codex
|
package codex
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -68,7 +66,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
effort := ""
|
effort := ""
|
||||||
support := modelInfo.Thinking
|
support := modelInfo.Thinking
|
||||||
if config.Budget == 0 {
|
if config.Budget == 0 {
|
||||||
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
|
if support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||||
effort = string(thinking.LevelNone)
|
effort = string(thinking.LevelNone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -120,12 +118,3 @@ func applyCompatibleCodex(body []byte, config thinking.ThinkingConfig) ([]byte,
|
|||||||
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
|
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasLevel(levels []string, target string) bool {
|
|
||||||
for _, level := range levels {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(level), target) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||||
//
|
//
|
||||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
// Kimi models use the OpenAI-compatible reasoning_effort format for enabled thinking
|
||||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
// levels, but use thinking.type=disabled when thinking is explicitly turned off.
|
||||||
// the unified ThinkingConfig in OpenAI format.
|
|
||||||
package kimi
|
package kimi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -17,8 +16,8 @@ import (
|
|||||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||||
//
|
//
|
||||||
// Kimi-specific behavior:
|
// Kimi-specific behavior:
|
||||||
// - Output format: reasoning_effort (string: low/medium/high)
|
// - Enabled thinking: reasoning_effort (string levels)
|
||||||
// - Uses OpenAI-compatible format
|
// - Disabled thinking: thinking.type="disabled"
|
||||||
// - Supports budget-to-level conversion
|
// - Supports budget-to-level conversion
|
||||||
type Applier struct{}
|
type Applier struct{}
|
||||||
|
|
||||||
@@ -35,11 +34,19 @@ func init() {
|
|||||||
|
|
||||||
// Apply applies thinking configuration to Kimi request body.
|
// Apply applies thinking configuration to Kimi request body.
|
||||||
//
|
//
|
||||||
// Expected output format:
|
// Expected output format (enabled):
|
||||||
//
|
//
|
||||||
// {
|
// {
|
||||||
// "reasoning_effort": "high"
|
// "reasoning_effort": "high"
|
||||||
// }
|
// }
|
||||||
|
//
|
||||||
|
// Expected output format (disabled):
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "thinking": {
|
||||||
|
// "type": "disabled"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
if thinking.IsUserDefinedModel(modelInfo) {
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
return applyCompatibleKimi(body, config)
|
return applyCompatibleKimi(body, config)
|
||||||
@@ -60,8 +67,13 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
}
|
}
|
||||||
effort = string(config.Level)
|
effort = string(config.Level)
|
||||||
case thinking.ModeNone:
|
case thinking.ModeNone:
|
||||||
// Kimi uses "none" to disable thinking
|
// Respect clamped fallback level for models that cannot disable thinking.
|
||||||
effort = string(thinking.LevelNone)
|
if config.Level != "" && config.Level != thinking.LevelNone {
|
||||||
|
effort = string(config.Level)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Kimi requires explicit disabled thinking object.
|
||||||
|
return applyDisabledThinking(body)
|
||||||
case thinking.ModeBudget:
|
case thinking.ModeBudget:
|
||||||
// Convert budget to level using threshold mapping
|
// Convert budget to level using threshold mapping
|
||||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||||
@@ -79,12 +91,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
if effort == "" {
|
if effort == "" {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
return applyReasoningEffort(body, effort)
|
||||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
|
||||||
if err != nil {
|
|
||||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||||
@@ -101,7 +108,9 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
|
|||||||
}
|
}
|
||||||
effort = string(config.Level)
|
effort = string(config.Level)
|
||||||
case thinking.ModeNone:
|
case thinking.ModeNone:
|
||||||
effort = string(thinking.LevelNone)
|
if config.Level == "" || config.Level == thinking.LevelNone {
|
||||||
|
return applyDisabledThinking(body)
|
||||||
|
}
|
||||||
if config.Level != "" {
|
if config.Level != "" {
|
||||||
effort = string(config.Level)
|
effort = string(config.Level)
|
||||||
}
|
}
|
||||||
@@ -118,9 +127,33 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
return applyReasoningEffort(body, effort)
|
||||||
if err != nil {
|
}
|
||||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
|
||||||
|
func applyReasoningEffort(body []byte, effort string) ([]byte, error) {
|
||||||
|
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
|
||||||
|
if errDeleteThinking != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
|
||||||
|
}
|
||||||
|
result, errSetEffort := sjson.SetBytes(result, "reasoning_effort", effort)
|
||||||
|
if errSetEffort != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", errSetEffort)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDisabledThinking(body []byte) ([]byte, error) {
|
||||||
|
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
|
||||||
|
if errDeleteThinking != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
|
||||||
|
}
|
||||||
|
result, errDeleteEffort := sjson.DeleteBytes(result, "reasoning_effort")
|
||||||
|
if errDeleteEffort != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to clear reasoning_effort: %w", errDeleteEffort)
|
||||||
|
}
|
||||||
|
result, errSetType := sjson.SetBytes(result, "thinking.type", "disabled")
|
||||||
|
if errSetType != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to set thinking.type: %w", errSetType)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|||||||
72
internal/thinking/provider/kimi/apply_test.go
Normal file
72
internal/thinking/provider/kimi/apply_test.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package kimi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApply_ModeNone_UsesDisabledThinking(t *testing.T) {
|
||||||
|
applier := NewApplier()
|
||||||
|
modelInfo := ®istry.ModelInfo{
|
||||||
|
ID: "kimi-k2.5",
|
||||||
|
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"kimi-k2.5","reasoning_effort":"none","thinking":{"type":"enabled","budget_tokens":2048}}`)
|
||||||
|
|
||||||
|
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
|
||||||
|
if errApply != nil {
|
||||||
|
t.Fatalf("Apply() error = %v", errApply)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
|
||||||
|
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||||
|
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "reasoning_effort").Exists() {
|
||||||
|
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApply_ModeLevel_UsesReasoningEffort(t *testing.T) {
|
||||||
|
applier := NewApplier()
|
||||||
|
modelInfo := ®istry.ModelInfo{
|
||||||
|
ID: "kimi-k2.5",
|
||||||
|
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"kimi-k2.5","thinking":{"type":"disabled"}}`)
|
||||||
|
|
||||||
|
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeLevel, Level: thinking.LevelHigh}, modelInfo)
|
||||||
|
if errApply != nil {
|
||||||
|
t.Fatalf("Apply() error = %v", errApply)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "reasoning_effort").String(); got != "high" {
|
||||||
|
t.Fatalf("reasoning_effort = %q, want %q, body=%s", got, "high", string(out))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "thinking").Exists() {
|
||||||
|
t.Fatalf("thinking should be removed when reasoning_effort is used, body=%s", string(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApply_UserDefinedModeNone_UsesDisabledThinking(t *testing.T) {
|
||||||
|
applier := NewApplier()
|
||||||
|
modelInfo := ®istry.ModelInfo{
|
||||||
|
ID: "custom-kimi-model",
|
||||||
|
UserDefined: true,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"custom-kimi-model","reasoning_effort":"none"}`)
|
||||||
|
|
||||||
|
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
|
||||||
|
if errApply != nil {
|
||||||
|
t.Fatalf("Apply() error = %v", errApply)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
|
||||||
|
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "reasoning_effort").Exists() {
|
||||||
|
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,57 +6,12 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validReasoningEffortLevels contains the standard values accepted by the
|
|
||||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
|
||||||
// auto) are NOT in this set and must be clamped before use.
|
|
||||||
var validReasoningEffortLevels = map[string]struct{}{
|
|
||||||
"none": {},
|
|
||||||
"low": {},
|
|
||||||
"medium": {},
|
|
||||||
"high": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
|
||||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
|
||||||
// mapped to the nearest standard equivalent.
|
|
||||||
//
|
|
||||||
// Mapping rules:
|
|
||||||
// - none / low / medium / high → returned as-is (already valid)
|
|
||||||
// - xhigh → "high" (nearest lower standard level)
|
|
||||||
// - minimal → "low" (nearest higher standard level)
|
|
||||||
// - auto → "medium" (reasonable default)
|
|
||||||
// - anything else → "medium" (safe default)
|
|
||||||
func clampReasoningEffort(level string) string {
|
|
||||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
|
||||||
return level
|
|
||||||
}
|
|
||||||
var clamped string
|
|
||||||
switch level {
|
|
||||||
case string(thinking.LevelXHigh):
|
|
||||||
clamped = string(thinking.LevelHigh)
|
|
||||||
case string(thinking.LevelMinimal):
|
|
||||||
clamped = string(thinking.LevelLow)
|
|
||||||
case string(thinking.LevelAuto):
|
|
||||||
clamped = string(thinking.LevelMedium)
|
|
||||||
default:
|
|
||||||
clamped = string(thinking.LevelMedium)
|
|
||||||
}
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"original": level,
|
|
||||||
"clamped": clamped,
|
|
||||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
|
||||||
return clamped
|
|
||||||
}
|
|
||||||
|
|
||||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||||
//
|
//
|
||||||
// OpenAI-specific behavior:
|
// OpenAI-specific behavior:
|
||||||
@@ -101,14 +56,14 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.Mode == thinking.ModeLevel {
|
if config.Mode == thinking.ModeLevel {
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
effort := ""
|
effort := ""
|
||||||
support := modelInfo.Thinking
|
support := modelInfo.Thinking
|
||||||
if config.Budget == 0 {
|
if config.Budget == 0 {
|
||||||
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
|
if support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||||
effort = string(thinking.LevelNone)
|
effort = string(thinking.LevelNone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -122,7 +77,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,15 +112,6 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasLevel(levels []string, target string) bool {
|
|
||||||
for _, level := range levels {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(level), target) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -30,13 +30,18 @@ func StripThinkingConfig(body []byte, provider string) []byte {
|
|||||||
var paths []string
|
var paths []string
|
||||||
switch provider {
|
switch provider {
|
||||||
case "claude":
|
case "claude":
|
||||||
paths = []string{"thinking"}
|
paths = []string{"thinking", "output_config.effort"}
|
||||||
case "gemini":
|
case "gemini":
|
||||||
paths = []string{"generationConfig.thinkingConfig"}
|
paths = []string{"generationConfig.thinkingConfig"}
|
||||||
case "gemini-cli", "antigravity":
|
case "gemini-cli", "antigravity":
|
||||||
paths = []string{"request.generationConfig.thinkingConfig"}
|
paths = []string{"request.generationConfig.thinkingConfig"}
|
||||||
case "openai":
|
case "openai":
|
||||||
paths = []string{"reasoning_effort"}
|
paths = []string{"reasoning_effort"}
|
||||||
|
case "kimi":
|
||||||
|
paths = []string{
|
||||||
|
"reasoning_effort",
|
||||||
|
"thinking",
|
||||||
|
}
|
||||||
case "codex":
|
case "codex":
|
||||||
paths = []string{"reasoning.effort"}
|
paths = []string{"reasoning.effort"}
|
||||||
case "iflow":
|
case "iflow":
|
||||||
@@ -54,5 +59,12 @@ func StripThinkingConfig(body []byte, provider string) []byte {
|
|||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
result, _ = sjson.DeleteBytes(result, path)
|
result, _ = sjson.DeleteBytes(result, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Avoid leaving an empty output_config object for Claude when effort was the only field.
|
||||||
|
if provider == "claude" {
|
||||||
|
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
|
||||||
|
result, _ = sjson.DeleteBytes(result, "output_config")
|
||||||
|
}
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func ParseSpecialSuffix(rawSuffix string) (mode ThinkingMode, ok bool) {
|
|||||||
// ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level.
|
// ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level.
|
||||||
//
|
//
|
||||||
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level.
|
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level.
|
||||||
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh.
|
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh, max.
|
||||||
// Level matching is case-insensitive.
|
// Level matching is case-insensitive.
|
||||||
//
|
//
|
||||||
// Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix
|
// Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix
|
||||||
@@ -140,6 +140,8 @@ func ParseLevelSuffix(rawSuffix string) (level ThinkingLevel, ok bool) {
|
|||||||
return LevelHigh, true
|
return LevelHigh, true
|
||||||
case "xhigh":
|
case "xhigh":
|
||||||
return LevelXHigh, true
|
return LevelXHigh, true
|
||||||
|
case "max":
|
||||||
|
return LevelMax, true
|
||||||
default:
|
default:
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ const (
|
|||||||
LevelHigh ThinkingLevel = "high"
|
LevelHigh ThinkingLevel = "high"
|
||||||
// LevelXHigh sets extra-high thinking effort
|
// LevelXHigh sets extra-high thinking effort
|
||||||
LevelXHigh ThinkingLevel = "xhigh"
|
LevelXHigh ThinkingLevel = "xhigh"
|
||||||
|
// LevelMax sets maximum thinking effort.
|
||||||
|
// This is currently used by Claude 4.6 adaptive thinking (opus supports "max").
|
||||||
|
LevelMax ThinkingLevel = "max"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ThinkingConfig represents a unified thinking configuration.
|
// ThinkingConfig represents a unified thinking configuration.
|
||||||
|
|||||||
@@ -53,7 +53,17 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFo
|
|||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
allowClampUnsupported := isBudgetBasedProvider(fromFormat) && isLevelBasedProvider(toFormat)
|
// allowClampUnsupported determines whether to clamp unsupported levels instead of returning an error.
|
||||||
|
// This applies when crossing provider families (e.g., openai→gemini, claude→gemini) and the target
|
||||||
|
// model supports discrete levels. Same-family conversions require strict validation.
|
||||||
|
toCapability := detectModelCapability(modelInfo)
|
||||||
|
toHasLevelSupport := toCapability == CapabilityLevelOnly || toCapability == CapabilityHybrid
|
||||||
|
allowClampUnsupported := toHasLevelSupport && !isSameProviderFamily(fromFormat, toFormat)
|
||||||
|
|
||||||
|
// strictBudget determines whether to enforce strict budget range validation.
|
||||||
|
// This applies when: (1) config comes from request body (not suffix), (2) source format is known,
|
||||||
|
// and (3) source and target are in the same provider family. Cross-family or suffix-based configs
|
||||||
|
// are clamped instead of rejected to improve interoperability.
|
||||||
strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat)
|
strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat)
|
||||||
budgetDerivedFromLevel := false
|
budgetDerivedFromLevel := false
|
||||||
|
|
||||||
@@ -201,7 +211,7 @@ func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupp
|
|||||||
}
|
}
|
||||||
|
|
||||||
// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest.
|
// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest.
|
||||||
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh}
|
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh, LevelMax}
|
||||||
|
|
||||||
// clampLevel clamps the given level to the nearest supported level.
|
// clampLevel clamps the given level to the nearest supported level.
|
||||||
// On tie, prefers the lower level.
|
// On tie, prefers the lower level.
|
||||||
@@ -325,7 +335,9 @@ func normalizeLevels(levels []string) []string {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func isBudgetBasedProvider(provider string) bool {
|
// isBudgetCapableProvider returns true if the provider supports budget-based thinking.
|
||||||
|
// These providers may also support level-based thinking (hybrid models).
|
||||||
|
func isBudgetCapableProvider(provider string) bool {
|
||||||
switch provider {
|
switch provider {
|
||||||
case "gemini", "gemini-cli", "antigravity", "claude":
|
case "gemini", "gemini-cli", "antigravity", "claude":
|
||||||
return true
|
return true
|
||||||
@@ -334,15 +346,6 @@ func isBudgetBasedProvider(provider string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLevelBasedProvider(provider string) bool {
|
|
||||||
switch provider {
|
|
||||||
case "openai", "openai-response", "codex":
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isGeminiFamily(provider string) bool {
|
func isGeminiFamily(provider string) bool {
|
||||||
switch provider {
|
switch provider {
|
||||||
case "gemini", "gemini-cli", "antigravity":
|
case "gemini", "gemini-cli", "antigravity":
|
||||||
@@ -352,11 +355,21 @@ func isGeminiFamily(provider string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpenAIFamily(provider string) bool {
|
||||||
|
switch provider {
|
||||||
|
case "openai", "openai-response", "codex":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isSameProviderFamily(from, to string) bool {
|
func isSameProviderFamily(from, to string) bool {
|
||||||
if from == to {
|
if from == to {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return isGeminiFamily(from) && isGeminiFamily(to)
|
return (isGeminiFamily(from) && isGeminiFamily(to)) ||
|
||||||
|
(isOpenAIFamily(from) && isOpenAIFamily(to))
|
||||||
}
|
}
|
||||||
|
|
||||||
func abs(x int) int {
|
func abs(x int) int {
|
||||||
|
|||||||
@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||||
} else if functionResponseResult.IsArray() {
|
} else if functionResponseResult.IsArray() {
|
||||||
frResults := functionResponseResult.Array()
|
frResults := functionResponseResult.Array()
|
||||||
if len(frResults) == 1 {
|
nonImageCount := 0
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
lastNonImageRaw := ""
|
||||||
|
filteredJSON := "[]"
|
||||||
|
imagePartsJSON := "[]"
|
||||||
|
for _, fr := range frResults {
|
||||||
|
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||||
|
inlineDataJSON := `{}`
|
||||||
|
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
|
}
|
||||||
|
if data := fr.Get("source.data").String(); data != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePartJSON := `{}`
|
||||||
|
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
|
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nonImageCount++
|
||||||
|
lastNonImageRaw = fr.Raw
|
||||||
|
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nonImageCount == 1 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
||||||
|
} else if nonImageCount > 1 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
||||||
} else {
|
} else {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Place image data inside functionResponse.parts as inlineData
|
||||||
|
// instead of as sibling parts in the outer content, to avoid
|
||||||
|
// base64 data bloating the text context.
|
||||||
|
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if functionResponseResult.IsObject() {
|
} else if functionResponseResult.IsObject() {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||||
|
inlineDataJSON := `{}`
|
||||||
|
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
|
}
|
||||||
|
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||||
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePartJSON := `{}`
|
||||||
|
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
|
imagePartsJSON := "[]"
|
||||||
|
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
|
} else {
|
||||||
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
|
}
|
||||||
} else if functionResponseResult.Raw != "" {
|
} else if functionResponseResult.Raw != "" {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
} else {
|
} else {
|
||||||
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if sourceResult.Get("type").String() == "base64" {
|
if sourceResult.Get("type").String() == "base64" {
|
||||||
inlineDataJSON := `{}`
|
inlineDataJSON := `{}`
|
||||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
||||||
}
|
}
|
||||||
if data := sourceResult.Get("data").String(); data != "" {
|
if data := sourceResult.Get("data").String(); data != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||||
@@ -349,7 +400,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
hasTools := toolDeclCount > 0
|
hasTools := toolDeclCount > 0
|
||||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||||
thinkingType := thinkingResult.Get("type").String()
|
thinkingType := thinkingResult.Get("type").String()
|
||||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive" || thinkingType == "auto")
|
||||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||||
|
|
||||||
if hasTools && hasThinking && isClaudeThinking {
|
if hasTools && hasThinking && isClaudeThinking {
|
||||||
@@ -380,6 +431,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tool_choice
|
||||||
|
toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
|
||||||
|
if toolChoiceResult.Exists() {
|
||||||
|
toolChoiceType := ""
|
||||||
|
toolChoiceName := ""
|
||||||
|
if toolChoiceResult.IsObject() {
|
||||||
|
toolChoiceType = toolChoiceResult.Get("type").String()
|
||||||
|
toolChoiceName = toolChoiceResult.Get("name").String()
|
||||||
|
} else if toolChoiceResult.Type == gjson.String {
|
||||||
|
toolChoiceType = toolChoiceResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch toolChoiceType {
|
||||||
|
case "auto":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||||
|
case "none":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||||
|
case "any":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||||
|
case "tool":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||||
|
if toolChoiceName != "" {
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||||
switch t.Get("type").String() {
|
switch t.Get("type").String() {
|
||||||
@@ -389,10 +467,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
case "adaptive":
|
case "adaptive", "auto":
|
||||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
// For adaptive thinking:
|
||||||
// to model-specific max capability.
|
// - If output_config.effort is explicitly present, pass through as thinkingLevel.
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
// - Otherwise, treat it as "enabled with target-model maximum" and emit high.
|
||||||
|
// ApplyThinking handles clamping to target model's supported levels.
|
||||||
|
effort := ""
|
||||||
|
if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||||
|
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
|
}
|
||||||
|
if effort != "" {
|
||||||
|
if effort == "max" {
|
||||||
|
effort = "high"
|
||||||
|
}
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||||
|
} else {
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
|
}
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -193,6 +193,42 @@ func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolChoice_SpecificTool(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gemini-3-flash-preview",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "hi"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "json",
|
||||||
|
"description": "A JSON tool",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tool_choice": {"type": "tool", "name": "json"}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("gemini-3-flash-preview", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if got := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" {
|
||||||
|
t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got)
|
||||||
|
}
|
||||||
|
allowed := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array()
|
||||||
|
if len(allowed) != 1 || allowed[0].String() != "json" {
|
||||||
|
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-3-5-sonnet-20240620",
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
@@ -413,8 +449,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
|||||||
if !inlineData.Exists() {
|
if !inlineData.Exists() {
|
||||||
t.Error("inlineData should exist")
|
t.Error("inlineData should exist")
|
||||||
}
|
}
|
||||||
if inlineData.Get("mime_type").String() != "image/png" {
|
if inlineData.Get("mimeType").String() != "image/png" {
|
||||||
t.Error("mime_type mismatch")
|
t.Error("mimeType mismatch")
|
||||||
}
|
}
|
||||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||||
t.Error("data mismatch")
|
t.Error("data mismatch")
|
||||||
@@ -740,6 +776,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
|
||||||
|
// tool_result with array content containing text + image should place
|
||||||
|
// image data inside functionResponse.parts as inlineData, not as a
|
||||||
|
// sibling part in the outer content (to avoid base64 context bloat).
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-123-456",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "File content here"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be inside functionResponse.parts, not as outer sibling part
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text content should be in response.result
|
||||||
|
resultText := funcResp.Get("response.result.text").String()
|
||||||
|
if resultText != "File content here" {
|
||||||
|
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be in functionResponse.parts[0].inlineData
|
||||||
|
inlineData := funcResp.Get("parts.0.inlineData")
|
||||||
|
if !inlineData.Exists() {
|
||||||
|
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||||
|
}
|
||||||
|
if inlineData.Get("mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||||
|
t.Error("data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should NOT be in outer parts (only functionResponse part should exist)
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||||
|
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||||
|
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
|
||||||
|
// tool_result with single image object as content should place
|
||||||
|
// image data inside functionResponse.parts, not as outer sibling part.
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-789-012",
|
||||||
|
"content": {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/jpeg",
|
||||||
|
"data": "/9j/4AAQSkZJRgABAQ=="
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// response.result should be empty (image only)
|
||||||
|
if funcResp.Get("response.result").String() != "" {
|
||||||
|
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should be in functionResponse.parts[0].inlineData
|
||||||
|
inlineData := funcResp.Get("parts.0.inlineData")
|
||||||
|
if !inlineData.Exists() {
|
||||||
|
t.Fatal("functionResponse.parts[0].inlineData should exist")
|
||||||
|
}
|
||||||
|
if inlineData.Get("mimeType").String() != "image/jpeg" {
|
||||||
|
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image should NOT be in outer parts
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
|
||||||
|
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
|
||||||
|
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
|
||||||
|
// tool_result with array content: 2 text items + 2 images
|
||||||
|
// All images go into functionResponse.parts, texts into response.result array
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Multi-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "First text"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Second text"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple text items => response.result is an array
|
||||||
|
resultArr := funcResp.Get("response.result")
|
||||||
|
if !resultArr.IsArray() {
|
||||||
|
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
|
||||||
|
}
|
||||||
|
results := resultArr.Array()
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both images should be in functionResponse.parts
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||||
|
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
|
||||||
|
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
|
||||||
|
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 outer part (the functionResponse itself)
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(outerParts) != 1 {
|
||||||
|
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
|
||||||
|
// tool_result with only images (no text) — response.result should be empty string
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "ImgOnly-001",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No text => response.result should be empty string
|
||||||
|
if funcResp.Get("response.result").String() != "" {
|
||||||
|
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both images in functionResponse.parts
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Error("first image mimeType mismatch")
|
||||||
|
}
|
||||||
|
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
|
||||||
|
t.Error("second image mimeType mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only 1 outer part
|
||||||
|
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(outerParts) != 1 {
|
||||||
|
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
|
||||||
|
// image with source.type != "base64" should be treated as non-image (falls through)
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NotB64-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "some output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "url", "url": "https://example.com/img.png"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-base64 image is treated as non-image, so it goes into the filtered results
|
||||||
|
// along with the text item. Since there are 2 non-image items, result is array.
|
||||||
|
resultArr := funcResp.Get("response.result")
|
||||||
|
if !resultArr.IsArray() {
|
||||||
|
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
|
||||||
|
}
|
||||||
|
results := resultArr.Array()
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Expected 2 result items, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// No functionResponse.parts (no base64 images collected)
|
||||||
|
if funcResp.Get("parts").Exists() {
|
||||||
|
t.Error("functionResponse.parts should NOT exist when no base64 images")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
|
||||||
|
// image with source.type=base64 but missing data field
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NoData-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "media_type": "image/png"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The image is still classified as base64 image (type check passes),
|
||||||
|
// but data field is missing => inlineData has mimeType but no data
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Error("mimeType should still be set")
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").Exists() {
|
||||||
|
t.Error("data should not exist when source.data is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
|
||||||
|
// image with source.type=base64 but missing media_type field
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "NoMime-001",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "output"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {"type": "base64", "data": "AAAA"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The image is still classified as base64 image,
|
||||||
|
// but media_type is missing => inlineData has data but no mimeType
|
||||||
|
imgParts := funcResp.Get("parts").Array()
|
||||||
|
if len(imgParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.mimeType").Exists() {
|
||||||
|
t.Error("mimeType should not exist when media_type is missing")
|
||||||
|
}
|
||||||
|
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
|
||||||
|
t.Error("data should still be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||||
// When tools + thinking but no system instruction, should create one with hint
|
// When tools + thinking but no system instruction, should create one with hint
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
@@ -776,3 +1235,64 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
|
|||||||
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
effort string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"low", "low", "low"},
|
||||||
|
{"medium", "medium", "medium"},
|
||||||
|
{"high", "high", "high"},
|
||||||
|
{"max", "max", "high"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||||
|
"thinking": {"type": "adaptive"},
|
||||||
|
"output_config": {"effort": "` + tt.effort + `"}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||||
|
if !thinkingConfig.Exists() {
|
||||||
|
t.Fatal("thinkingConfig should exist for adaptive thinking")
|
||||||
|
}
|
||||||
|
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
|
||||||
|
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
|
||||||
|
}
|
||||||
|
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||||
|
t.Error("includeThoughts should be true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||||
|
"thinking": {"type": "adaptive"}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||||
|
if !thinkingConfig.Exists() {
|
||||||
|
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
|
||||||
|
}
|
||||||
|
if thinkingConfig.Get("thinkingLevel").String() != "high" {
|
||||||
|
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
|
||||||
|
}
|
||||||
|
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||||
|
t.Error("includeThoughts should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
||||||
|
// When functionResponse contains a "parts" field with inlineData (from Claude
|
||||||
|
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
|
||||||
|
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
|
||||||
|
// so extra fields like "parts" survive the pipeline.
|
||||||
|
input := `{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionCall": {"name": "screenshot", "args": {}}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionResponse": {
|
||||||
|
"id": "tool-001",
|
||||||
|
"name": "screenshot",
|
||||||
|
"response": {"result": "Screenshot taken"},
|
||||||
|
"parts": [
|
||||||
|
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the function response content (role=function)
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The functionResponse should be preserved with its parts field
|
||||||
|
funcResp := funcContent.Get("parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the parts field with inlineData is preserved
|
||||||
|
inlineParts := funcResp.Get("parts").Array()
|
||||||
|
if len(inlineParts) != 1 {
|
||||||
|
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
|
||||||
|
}
|
||||||
|
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
|
||||||
|
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
|
||||||
|
}
|
||||||
|
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
|
||||||
|
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify response.result is also preserved
|
||||||
|
if funcResp.Get("response.result").String() != "Screenshot taken" {
|
||||||
|
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
// Model
|
// Model
|
||||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||||
|
|
||||||
|
// Let user-provided generationConfig pass through
|
||||||
|
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
|
||||||
|
}
|
||||||
|
|
||||||
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
|
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
|
||||||
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
||||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||||
@@ -187,7 +192,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
mime := pieces[0]
|
mime := pieces[0]
|
||||||
data := pieces[1][7:]
|
data := pieces[1][7:]
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
@@ -201,7 +206,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
ext = sp[len(sp)-1]
|
ext = sp[len(sp)-1]
|
||||||
}
|
}
|
||||||
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
|
||||||
p++
|
p++
|
||||||
} else {
|
} else {
|
||||||
@@ -235,7 +240,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||||
mime := pieces[0]
|
mime := pieces[0]
|
||||||
data := pieces[1][7:]
|
data := pieces[1][7:]
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
p++
|
p++
|
||||||
|
|||||||
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -115,24 +116,47 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
// Include thoughts configuration for reasoning process visibility
|
// Include thoughts configuration for reasoning process visibility
|
||||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||||
|
mi := registry.LookupModelInfo(modelName, "claude")
|
||||||
|
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
|
||||||
|
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
|
||||||
|
|
||||||
|
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
|
||||||
|
// validation errors since validate treats same-provider unsupported levels as errors.
|
||||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||||
if !thinkingLevel.Exists() {
|
if !thinkingLevel.Exists() {
|
||||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||||
}
|
}
|
||||||
if thinkingLevel.Exists() {
|
if thinkingLevel.Exists() {
|
||||||
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||||
switch level {
|
if supportsAdaptive {
|
||||||
case "":
|
switch level {
|
||||||
case "none":
|
case "":
|
||||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
case "none":
|
||||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
case "auto":
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Delete(out, "output_config.effort")
|
||||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
default:
|
||||||
default:
|
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
|
||||||
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
|
level = mapped
|
||||||
|
}
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch level {
|
||||||
|
case "":
|
||||||
|
case "none":
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
case "auto":
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
default:
|
||||||
|
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
|
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -142,16 +166,35 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
}
|
}
|
||||||
if thinkingBudget.Exists() {
|
if thinkingBudget.Exists() {
|
||||||
budget := int(thinkingBudget.Int())
|
budget := int(thinkingBudget.Int())
|
||||||
switch budget {
|
if supportsAdaptive {
|
||||||
case 0:
|
switch budget {
|
||||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
case 0:
|
||||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
case -1:
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Delete(out, "output_config.effort")
|
||||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
default:
|
||||||
default:
|
level, ok := thinking.ConvertBudgetToLevel(budget)
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
if ok {
|
||||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
|
||||||
|
level = mapped
|
||||||
|
}
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
out, _ = sjson.Set(out, "output_config.effort", level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch budget {
|
||||||
|
case 0:
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
case -1:
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
default:
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
|
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -68,17 +69,45 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
if effort != "" {
|
if effort != "" {
|
||||||
budget, ok := thinking.ConvertLevelToBudget(effort)
|
mi := registry.LookupModelInfo(modelName, "claude")
|
||||||
if ok {
|
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
|
||||||
switch budget {
|
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
|
||||||
case 0:
|
|
||||||
|
// Claude 4.6 supports adaptive thinking with output_config.effort.
|
||||||
|
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
|
||||||
|
// validation errors since validate treats same-provider unsupported levels as errors.
|
||||||
|
if supportsAdaptive {
|
||||||
|
switch effort {
|
||||||
|
case "none":
|
||||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
case -1:
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Delete(out, "output_config.effort")
|
||||||
|
case "auto":
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
out, _ = sjson.Delete(out, "output_config.effort")
|
||||||
default:
|
default:
|
||||||
if budget > 0 {
|
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||||
|
effort = mapped
|
||||||
|
}
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Legacy/manual thinking (budget_tokens).
|
||||||
|
budget, ok := thinking.ConvertLevelToBudget(effort)
|
||||||
|
if ok {
|
||||||
|
switch budget {
|
||||||
|
case 0:
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
|
case -1:
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
default:
|
||||||
|
if budget > 0 {
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
|
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,6 +228,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "file":
|
||||||
|
fileData := part.Get("file.file_data").String()
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
semicolonIdx := strings.Index(fileData, ";")
|
||||||
|
commaIdx := strings.Index(fileData, ",")
|
||||||
|
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||||
|
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||||
|
data := fileData[commaIdx+1:]
|
||||||
|
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -56,17 +57,45 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
if v := root.Get("reasoning.effort"); v.Exists() {
|
if v := root.Get("reasoning.effort"); v.Exists() {
|
||||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
if effort != "" {
|
if effort != "" {
|
||||||
budget, ok := thinking.ConvertLevelToBudget(effort)
|
mi := registry.LookupModelInfo(modelName, "claude")
|
||||||
if ok {
|
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
|
||||||
switch budget {
|
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
|
||||||
case 0:
|
|
||||||
|
// Claude 4.6 supports adaptive thinking with output_config.effort.
|
||||||
|
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
|
||||||
|
// validation errors since validate treats same-provider unsupported levels as errors.
|
||||||
|
if supportsAdaptive {
|
||||||
|
switch effort {
|
||||||
|
case "none":
|
||||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
case -1:
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Delete(out, "output_config.effort")
|
||||||
|
case "auto":
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
out, _ = sjson.Delete(out, "output_config.effort")
|
||||||
default:
|
default:
|
||||||
if budget > 0 {
|
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
|
||||||
|
effort = mapped
|
||||||
|
}
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "adaptive")
|
||||||
|
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||||
|
out, _ = sjson.Set(out, "output_config.effort", effort)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Legacy/manual thinking (budget_tokens).
|
||||||
|
budget, ok := thinking.ConvertLevelToBudget(effort)
|
||||||
|
if ok {
|
||||||
|
switch budget {
|
||||||
|
case 0:
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||||
|
case -1:
|
||||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
default:
|
||||||
|
if budget > 0 {
|
||||||
|
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||||
|
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -155,6 +184,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
var textAggregate strings.Builder
|
var textAggregate strings.Builder
|
||||||
var partsJSON []string
|
var partsJSON []string
|
||||||
hasImage := false
|
hasImage := false
|
||||||
|
hasFile := false
|
||||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||||
parts.ForEach(func(_, part gjson.Result) bool {
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
ptype := part.Get("type").String()
|
ptype := part.Get("type").String()
|
||||||
@@ -207,6 +237,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
hasImage = true
|
hasImage = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "input_file":
|
||||||
|
fileData := part.Get("file_data").String()
|
||||||
|
if fileData != "" {
|
||||||
|
mediaType := "application/octet-stream"
|
||||||
|
data := fileData
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||||
|
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||||
|
if len(mediaAndData) == 2 {
|
||||||
|
if mediaAndData[0] != "" {
|
||||||
|
mediaType = mediaAndData[0]
|
||||||
|
}
|
||||||
|
data = mediaAndData[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||||
|
partsJSON = append(partsJSON, contentPart)
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
hasFile = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -228,7 +282,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
if len(partsJSON) > 0 {
|
if len(partsJSON) > 0 {
|
||||||
msg := `{"role":"","content":[]}`
|
msg := `{"role":"","content":[]}`
|
||||||
msg, _ = sjson.Set(msg, "role", role)
|
msg, _ = sjson.Set(msg, "role", role)
|
||||||
if len(partsJSON) == 1 && !hasImage {
|
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||||
// Preserve legacy behavior for single text content
|
// Preserve legacy behavior for single text content
|
||||||
msg, _ = sjson.Delete(msg, "content")
|
msg, _ = sjson.Delete(msg, "content")
|
||||||
textPart := gjson.Parse(partsJSON[0])
|
textPart := gjson.Parse(partsJSON[0])
|
||||||
|
|||||||
@@ -46,15 +46,23 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
if systemsResult.IsArray() {
|
if systemsResult.IsArray() {
|
||||||
systemResults := systemsResult.Array()
|
systemResults := systemsResult.Array()
|
||||||
message := `{"type":"message","role":"developer","content":[]}`
|
message := `{"type":"message","role":"developer","content":[]}`
|
||||||
|
contentIndex := 0
|
||||||
for i := 0; i < len(systemResults); i++ {
|
for i := 0; i < len(systemResults); i++ {
|
||||||
systemResult := systemResults[i]
|
systemResult := systemResults[i]
|
||||||
systemTypeResult := systemResult.Get("type")
|
systemTypeResult := systemResult.Get("type")
|
||||||
if systemTypeResult.String() == "text" {
|
if systemTypeResult.String() == "text" {
|
||||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", i), "input_text")
|
text := systemResult.Get("text").String()
|
||||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", i), systemResult.Get("text").String())
|
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||||
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||||
|
contentIndex++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
if contentIndex > 0 {
|
||||||
|
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages and transform their contents to appropriate formats.
|
// Process messages and transform their contents to appropriate formats.
|
||||||
@@ -152,7 +160,51 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
flushMessage()
|
flushMessage()
|
||||||
functionCallOutputMessage := `{"type":"function_call_output"}`
|
functionCallOutputMessage := `{"type":"function_call_output"}`
|
||||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||||
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
|
||||||
|
contentResult := messageContentResult.Get("content")
|
||||||
|
if contentResult.IsArray() {
|
||||||
|
toolResultContentIndex := 0
|
||||||
|
toolResultContent := `[]`
|
||||||
|
contentResults := contentResult.Array()
|
||||||
|
for k := 0; k < len(contentResults); k++ {
|
||||||
|
toolResultContentType := contentResults[k].Get("type").String()
|
||||||
|
if toolResultContentType == "image" {
|
||||||
|
sourceResult := contentResults[k].Get("source")
|
||||||
|
if sourceResult.Exists() {
|
||||||
|
data := sourceResult.Get("data").String()
|
||||||
|
if data == "" {
|
||||||
|
data = sourceResult.Get("base64").String()
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
mediaType := sourceResult.Get("media_type").String()
|
||||||
|
if mediaType == "" {
|
||||||
|
mediaType = sourceResult.Get("mime_type").String()
|
||||||
|
}
|
||||||
|
if mediaType == "" {
|
||||||
|
mediaType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
|
||||||
|
|
||||||
|
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
|
||||||
|
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
|
||||||
|
toolResultContentIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if toolResultContentType == "text" {
|
||||||
|
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
|
||||||
|
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
|
||||||
|
toolResultContentIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolResultContent != `[]` {
|
||||||
|
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
|
||||||
|
} else {
|
||||||
|
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||||
|
}
|
||||||
|
|
||||||
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
|
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -203,6 +255,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
|
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
|
||||||
tool, _ = sjson.Delete(tool, "input_schema")
|
tool, _ = sjson.Delete(tool, "input_schema")
|
||||||
tool, _ = sjson.Delete(tool, "parameters.$schema")
|
tool, _ = sjson.Delete(tool, "parameters.$schema")
|
||||||
|
tool, _ = sjson.Delete(tool, "cache_control")
|
||||||
|
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||||
tool, _ = sjson.Set(tool, "strict", false)
|
tool, _ = sjson.Set(tool, "strict", false)
|
||||||
template, _ = sjson.SetRaw(template, "tools.-1", tool)
|
template, _ = sjson.SetRaw(template, "tools.-1", tool)
|
||||||
}
|
}
|
||||||
@@ -222,10 +276,18 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
reasoningEffort = effort
|
reasoningEffort = effort
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "adaptive":
|
case "adaptive", "auto":
|
||||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
// Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6).
|
||||||
// and let ApplyThinking normalize per target model capability.
|
// Pass through directly; ApplyThinking handles clamping to target model's levels.
|
||||||
reasoningEffort = string(thinking.LevelXHigh)
|
effort := ""
|
||||||
|
if v := rootResult.Get("output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||||
|
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
|
}
|
||||||
|
if effort != "" {
|
||||||
|
reasoningEffort = effort
|
||||||
|
} else {
|
||||||
|
reasoningEffort = string(thinking.LevelXHigh)
|
||||||
|
}
|
||||||
case "disabled":
|
case "disabled":
|
||||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||||
reasoningEffort = effort
|
reasoningEffort = effort
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ var (
|
|||||||
|
|
||||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||||
type ConvertCodexResponseToClaudeParams struct {
|
type ConvertCodexResponseToClaudeParams struct {
|
||||||
HasToolCall bool
|
HasToolCall bool
|
||||||
BlockIndex int
|
BlockIndex int
|
||||||
HasReceivedArgumentsDelta bool
|
HasReceivedArgumentsDelta bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
}
|
}
|
||||||
case "file":
|
case "file":
|
||||||
// Files are not specified in examples; skip for now
|
if role == "user" {
|
||||||
|
fileData := it.Get("file.file_data").String()
|
||||||
|
filename := it.Get("file.filename").String()
|
||||||
|
if fileData != "" {
|
||||||
|
part := `{}`
|
||||||
|
part, _ = sjson.Set(part, "type", "input_file")
|
||||||
|
part, _ = sjson.Set(part, "file_data", fileData)
|
||||||
|
if filename != "" {
|
||||||
|
part, _ = sjson.Set(part, "filename", filename)
|
||||||
|
}
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,8 +74,13 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the model version.
|
// Extract and set the model version.
|
||||||
|
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
|
||||||
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
|
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "model", modelResult.String())
|
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||||
|
} else if cachedModel != "" {
|
||||||
|
template, _ = sjson.Set(template, "model", cachedModel)
|
||||||
|
} else if modelName != "" {
|
||||||
|
template, _ = sjson.Set(template, "model", modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
modelName := "gpt-5.3-codex"
|
||||||
|
|
||||||
|
out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.created","response":{"id":"resp_123","created_at":1700000000,"model":"gpt-5.3-codex"}}`), ¶m)
|
||||||
|
if len(out) != 0 {
|
||||||
|
t.Fatalf("expected no output for response.created, got %d chunks", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
out = ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m)
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotModel := gjson.Get(out[0], "model").String()
|
||||||
|
if gotModel != modelName {
|
||||||
|
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
modelName := "gpt-5.3-codex"
|
||||||
|
|
||||||
|
out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m)
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotModel := gjson.Get(out[0], "model").String()
|
||||||
|
if gotModel != modelName {
|
||||||
|
t.Fatalf("expected model %q, got %q", modelName, gotModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,7 +25,9 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
// rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||||
|
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||||
|
|
||||||
// Delete the user field as it is not supported by the Codex upstream.
|
// Delete the user field as it is not supported by the Codex upstream.
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
return rawJSON
|
return rawJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||||
|
// for Codex upstream compatibility.
|
||||||
|
//
|
||||||
|
// Codex /responses currently rejects context_management with:
|
||||||
|
// {"detail":"Unsupported parameter: context_management"}.
|
||||||
|
//
|
||||||
|
// Compatibility strategy:
|
||||||
|
// 1) Remove context_management before forwarding to Codex upstream.
|
||||||
|
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||||
|
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||||
// accept "system" role in the input array.
|
// accept "system" role in the input array.
|
||||||
|
|||||||
@@ -264,19 +264,57 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserFieldDeletion(t *testing.T) {
|
func TestUserFieldDeletion(t *testing.T) {
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "gpt-5.2",
|
"model": "gpt-5.2",
|
||||||
"user": "test-user",
|
"user": "test-user",
|
||||||
"input": [{"role": "user", "content": "Hello"}]
|
"input": [{"role": "user", "content": "Hello"}]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Verify user field is deleted
|
// Verify user field is deleted
|
||||||
userField := gjson.Get(outputStr, "user")
|
userField := gjson.Get(outputStr, "user")
|
||||||
if userField.Exists() {
|
if userField.Exists() {
|
||||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"context_management": [
|
||||||
|
{
|
||||||
|
"type": "compaction",
|
||||||
|
"compact_threshold": 12000
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "context_management").Exists() {
|
||||||
|
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"truncation": "disabled",
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,24 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||||
// to OpenAI Responses SSE events (response.*).
|
// to OpenAI Responses SSE events (response.*).
|
||||||
|
|
||||||
func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
|
||||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||||
if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() {
|
|
||||||
typeStr := typeResult.String()
|
|
||||||
if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" {
|
|
||||||
if gjson.GetBytes(rawJSON, "response.instructions").Exists() {
|
|
||||||
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
|
||||||
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out := fmt.Sprintf("data: %s", string(rawJSON))
|
out := fmt.Sprintf("data: %s", string(rawJSON))
|
||||||
return []string{out}
|
return []string{out}
|
||||||
}
|
}
|
||||||
@@ -32,17 +22,12 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string
|
|||||||
|
|
||||||
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
|
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
|
||||||
// from a non-streaming OpenAI Chat Completions response.
|
// from a non-streaming OpenAI Chat Completions response.
|
||||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
|
||||||
rootResult := gjson.ParseBytes(rawJSON)
|
rootResult := gjson.ParseBytes(rawJSON)
|
||||||
// Verify this is a response.completed event
|
// Verify this is a response.completed event
|
||||||
if rootResult.Get("type").String() != "response.completed" {
|
if rootResult.Get("type").String() != "response.completed" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
responseResult := rootResult.Get("response")
|
responseResult := rootResult.Get("response")
|
||||||
template := responseResult.Raw
|
return responseResult.Raw
|
||||||
if responseResult.Get("instructions").Exists() {
|
|
||||||
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
|
||||||
template, _ = sjson.Set(template, "instructions", instructions)
|
|
||||||
}
|
|
||||||
return template
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -156,6 +156,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
|||||||
tool, _ = sjson.Delete(tool, "input_examples")
|
tool, _ = sjson.Delete(tool, "input_examples")
|
||||||
tool, _ = sjson.Delete(tool, "type")
|
tool, _ = sjson.Delete(tool, "type")
|
||||||
tool, _ = sjson.Delete(tool, "cache_control")
|
tool, _ = sjson.Delete(tool, "cache_control")
|
||||||
|
tool, _ = sjson.Delete(tool, "defer_loading")
|
||||||
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
||||||
if !hasTools {
|
if !hasTools {
|
||||||
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
|
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
|
||||||
@@ -171,7 +172,35 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
// tool_choice
|
||||||
|
toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
|
||||||
|
if toolChoiceResult.Exists() {
|
||||||
|
toolChoiceType := ""
|
||||||
|
toolChoiceName := ""
|
||||||
|
if toolChoiceResult.IsObject() {
|
||||||
|
toolChoiceType = toolChoiceResult.Get("type").String()
|
||||||
|
toolChoiceName = toolChoiceResult.Get("name").String()
|
||||||
|
} else if toolChoiceResult.Type == gjson.String {
|
||||||
|
toolChoiceType = toolChoiceResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch toolChoiceType {
|
||||||
|
case "auto":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||||
|
case "none":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||||
|
case "any":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||||
|
case "tool":
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||||
|
if toolChoiceName != "" {
|
||||||
|
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map Anthropic thinking -> Gemini CLI thinkingConfig when enabled
|
||||||
|
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||||
switch t.Get("type").String() {
|
switch t.Get("type").String() {
|
||||||
case "enabled":
|
case "enabled":
|
||||||
@@ -180,10 +209,20 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
|||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
case "adaptive":
|
case "adaptive", "auto":
|
||||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
// For adaptive thinking:
|
||||||
// to model-specific max capability.
|
// - If output_config.effort is explicitly present, pass through as thinkingLevel.
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
// - Otherwise, treat it as "enabled with target-model maximum" and emit high.
|
||||||
|
// ApplyThinking handles clamping to target model's supported levels.
|
||||||
|
effort := ""
|
||||||
|
if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
|
||||||
|
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
|
}
|
||||||
|
if effort != "" {
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||||
|
} else {
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
|
}
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gemini-3-flash-preview",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "hi"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "json",
|
||||||
|
"description": "A JSON tool",
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tool_choice": {"type": "tool", "name": "json"}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" {
|
||||||
|
t.Fatalf("Expected request.toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got)
|
||||||
|
}
|
||||||
|
allowed := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array()
|
||||||
|
if len(allowed) != 1 || allowed[0].String() != "json" {
|
||||||
|
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
// Model
|
// Model
|
||||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||||
|
|
||||||
|
// Let user-provided generationConfig pass through
|
||||||
|
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
|
||||||
|
}
|
||||||
|
|
||||||
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
|
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
|
||||||
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
|
||||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user