mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-11 16:26:32 +00:00
Compare commits
299 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ebd8f0c44 | ||
|
|
b680c146c1 | ||
|
|
5c84d69d42 | ||
|
|
b48485b42b | ||
|
|
79009bb3d4 | ||
|
|
26fc611f86 | ||
|
|
b43743d4f1 | ||
|
|
179e5434b1 | ||
|
|
9f95b31158 | ||
|
|
5da07eae4c | ||
|
|
835ae178d4 | ||
|
|
c80ab8bf0d | ||
|
|
ce87714ef1 | ||
|
|
0452b869e8 | ||
|
|
d2e5857b82 | ||
|
|
f9b005f21f | ||
|
|
532107b4fa | ||
|
|
c44793789b | ||
|
|
4e99525279 | ||
|
|
7547d1d0b3 | ||
|
|
68934942d0 | ||
|
|
09fec34e1c | ||
|
|
9229708b6c | ||
|
|
914db94e79 | ||
|
|
660bd7eff5 | ||
|
|
b907d21851 | ||
|
|
d6cc976d1f | ||
|
|
8aa2cce8c5 | ||
|
|
bf9b2c49df | ||
|
|
77b42c6165 | ||
|
|
446150a747 | ||
|
|
1cbc4834e1 | ||
|
|
a8a5d03c33 | ||
|
|
76aa917882 | ||
|
|
6ac9b31e4e | ||
|
|
0ad3e8457f | ||
|
|
444a47ae63 | ||
|
|
725f4fdff4 | ||
|
|
c23e46f45d | ||
|
|
b148820c35 | ||
|
|
134f41496d | ||
|
|
c5838dd58d | ||
|
|
b6ca5ef7ce | ||
|
|
1ae994b4aa | ||
|
|
84e9793e61 | ||
|
|
32e64dacfd | ||
|
|
cc1d8f6629 | ||
|
|
5446cd2b02 | ||
|
|
8de0885b7d | ||
|
|
16243f18fd | ||
|
|
a6ce5f36e6 | ||
|
|
e73cf42e28 | ||
|
|
b45343e812 | ||
|
|
8599b1560e | ||
|
|
8bde8c37c0 | ||
|
|
82df5bf88a | ||
|
|
acb1066de8 | ||
|
|
27c68f5bb2 | ||
|
|
68dd2bfe82 | ||
|
|
65a87815e7 | ||
|
|
b80793ca82 | ||
|
|
601550f238 | ||
|
|
41b1cf2273 | ||
|
|
2baf35b3ef | ||
|
|
846e75b893 | ||
|
|
fc0257d6d9 | ||
|
|
f3c164d345 | ||
|
|
4040b1e766 | ||
|
|
3b4f9f43db | ||
|
|
37a09ecb23 | ||
|
|
0da34d3c2d | ||
|
|
74bf7eda8f | ||
|
|
9032042cfa | ||
|
|
030bf5e6c7 | ||
|
|
d3100085b0 | ||
|
|
f481d25133 | ||
|
|
8c6c90da74 | ||
|
|
24bcfd9c03 | ||
|
|
816fb4c5da | ||
|
|
c1bb77c7c9 | ||
|
|
6bcac3a55a | ||
|
|
fc346f4537 | ||
|
|
43e531a3b6 | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
4eeec297de | ||
|
|
77cc4ce3a0 | ||
|
|
37dfea1d3f | ||
|
|
e6626c672a | ||
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
c785c1a3ca | ||
|
|
514ae341c8 | ||
|
|
0659ffab75 | ||
|
|
8ce07f38dd | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
0aaf177640 | ||
|
|
450d1227bd | ||
|
|
492b9c46f0 | ||
|
|
6e634fe3f9 | ||
|
|
4e26182d14 | ||
|
|
8f97a5f77c | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
3b421c8181 | ||
|
|
21d2329947 | ||
|
|
0993413bab | ||
|
|
713388dd7b | ||
|
|
2a4d3e60f3 | ||
|
|
8b5af2ab84 | ||
|
|
e6c7af0fa9 | ||
|
|
837aa6e3aa | ||
|
|
d210be06c2 | ||
|
|
d887716ebd | ||
|
|
5dc1848466 | ||
|
|
9491517b26 | ||
|
|
9370b5bd04 | ||
|
|
abb51a0d93 | ||
|
|
c8d809131b | ||
|
|
dd71c73a9f | ||
|
|
afc8a0f9be | ||
|
|
af8e9ef458 | ||
|
|
cec6f993ad | ||
|
|
950de29f48 | ||
|
|
d6ec33e8e1 | ||
|
|
081cfe806e | ||
|
|
c1c62a6c04 | ||
|
|
a99522224f | ||
|
|
f5d46b9ca2 | ||
|
|
d693d7993b | ||
|
|
5936f9895c | ||
|
|
2fdf5d2793 | ||
|
|
b3da00d2ed | ||
|
|
740277a9f2 | ||
|
|
f91807b6b9 | ||
|
|
57d18bb226 | ||
|
|
10b9c6cb8a | ||
|
|
b24786f8a7 | ||
|
|
7b0eb41ebc | ||
|
|
70949929db | ||
|
|
7c9c89dace | ||
|
|
ef5901c81b | ||
|
|
d4829c82f7 | ||
|
|
a5f4166a9b | ||
|
|
0cbfe7f457 | ||
|
|
f2b1ec4f9e | ||
|
|
1cc21cc45b | ||
|
|
07cf616e2b | ||
|
|
2b8c466e88 | ||
|
|
ca2174ea48 | ||
|
|
c09fb2a79d | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
1a0ceda0fc | ||
|
|
b9ae4ab803 | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
ae4c502792 | ||
|
|
ec6068060b | ||
|
|
ecb01d3dcd | ||
|
|
22c0c00bd4 | ||
|
|
9eb3e7a6c4 | ||
|
|
357c191510 | ||
|
|
5db244af76 | ||
|
|
dc375d1b74 | ||
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
2bcee78c6e | ||
|
|
7fe8246a9f | ||
|
|
93fe58e31e | ||
|
|
e5b5dc870f | ||
|
|
a54877c023 | ||
|
|
bb86a0c0c4 | ||
|
|
5fa23c7f41 | ||
|
|
f9a09b7f23 | ||
|
|
b0cde626fe | ||
|
|
e42ef9a95d | ||
|
|
abf1629ec7 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
922d4141c0 | ||
|
|
1f8f198c45 | ||
|
|
c55275342c | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
5726a99c80 | ||
|
|
b5756bf729 | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd | ||
|
|
e186ccb0d4 | ||
|
|
8fc0b08b70 | ||
|
|
52a257dc24 | ||
|
|
a12d907f55 | ||
|
|
453aaf8774 | ||
|
|
1b1ab1fb9b | ||
|
|
a9d0bb72da | ||
|
|
d328e54e4b | ||
|
|
5a7932cba4 | ||
|
|
1dbeb0827a | ||
|
|
2c8821891c | ||
|
|
0a2555b0f3 | ||
|
|
020df41efe | ||
|
|
f8f8cf17ce | ||
|
|
f31f7f701a | ||
|
|
b5fe78eb70 | ||
|
|
d1f667cf8d | ||
|
|
54ad7c1b6b | ||
|
|
d560c20c26 | ||
|
|
5abeca1f9e | ||
|
|
294eac3a88 | ||
|
|
a31104020c | ||
|
|
65bec4d734 | ||
|
|
edb2993838 | ||
|
|
c0d8e0dec7 | ||
|
|
795da13d5d | ||
|
|
55789df275 | ||
|
|
9e652a3540 | ||
|
|
46a6782065 | ||
|
|
c359f61859 | ||
|
|
908c8eab5b | ||
|
|
f5f2c69233 | ||
|
|
63d4de5eea | ||
|
|
af15083496 | ||
|
|
c4722e42b1 | ||
|
|
f9a991365f | ||
|
|
6df16bedba | ||
|
|
632a2fd2f2 | ||
|
|
5626637fbd | ||
|
|
2db89211a9 | ||
|
|
587371eb14 | ||
|
|
75818b1e25 | ||
|
|
a45c6defa7 | ||
|
|
cbe56955a9 | ||
|
|
8ea6ac913d | ||
|
|
ae1e8a5191 | ||
|
|
b3ccc55f09 | ||
|
|
40bee3e8d9 | ||
|
|
1ce56d7413 | ||
|
|
41a78be3a2 | ||
|
|
1ff5de9a31 | ||
|
|
46a6853046 | ||
|
|
4b2d40bd67 | ||
|
|
726f1a590c | ||
|
|
575881cb59 | ||
|
|
d02df0141b | ||
|
|
e4bc9da913 | ||
|
|
8c6be49625 | ||
|
|
c727e4251f | ||
|
|
99266be998 | ||
|
|
086d8d0d0b | ||
|
|
627dee1dac | ||
|
|
93147dddeb | ||
|
|
c0f9b15a58 | ||
|
|
6f2fbdcbae | ||
|
|
55c3197fb8 | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
5a2cf0d53c | ||
|
|
2573358173 | ||
|
|
09cd3cff91 | ||
|
|
ab0bf1b517 | ||
|
|
544238772a | ||
|
|
f3ccd85ba1 | ||
|
|
9c65e17a21 | ||
|
|
ce0c6aa82b | ||
|
|
3c85d2a4d7 | ||
|
|
15bc99f6ea | ||
|
|
3ec7991e5f | ||
|
|
40e85a6759 | ||
|
|
cc116ce67d | ||
|
|
40efc2ba43 |
@@ -31,6 +31,7 @@ bin/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
go-version: '>=1.26.0'
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -3,10 +3,11 @@ cli-proxy-api
|
||||
cliproxy
|
||||
*.exe
|
||||
|
||||
|
||||
# Configuration
|
||||
config.yaml
|
||||
.env
|
||||
|
||||
.mcp.json
|
||||
# Generated content
|
||||
bin/*
|
||||
logs/*
|
||||
@@ -43,6 +44,7 @@ GEMINI.md
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
58
README.md
58
README.md
@@ -10,23 +10,59 @@ The Plus release stays in lockstep with the mainline features.
|
||||
|
||||
## Differences from the Mainline
|
||||
|
||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
## New Features (Plus Enhanced)
|
||||
|
||||
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
|
||||
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
|
||||
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
|
||||
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
|
||||
- **Device Fingerprint**: Device fingerprint generation for enhanced security
|
||||
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
|
||||
- **Usage Checker**: Real-time usage monitoring and quota management
|
||||
- **Model Converter**: Unified model name conversion across providers
|
||||
- **UTF-8 Stream Processing**: Improved streaming response handling
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
## Kiro Authentication
|
||||
|
||||
### CLI Login
|
||||
|
||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
|
||||
**AWS Builder ID** (recommended):
|
||||
|
||||
```bash
|
||||
# Device code flow
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# Authorization code flow
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**Import token from Kiro IDE:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
To get a token from Kiro IDE:
|
||||
|
||||
1. Open Kiro IDE and login with Google (or GitHub)
|
||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
||||
3. Run: `./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# Specify region
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**Additional flags:**
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
||||
|
||||
### Web-based OAuth Login
|
||||
|
||||
Access the Kiro OAuth web interface at:
|
||||
|
||||
60
README_CN.md
60
README_CN.md
@@ -10,22 +10,58 @@
|
||||
|
||||
## 与主线版本版本差异
|
||||
|
||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
## 新增功能 (Plus 增强版)
|
||||
|
||||
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
|
||||
- **请求限流器**: 内置请求限流,防止 API 滥用
|
||||
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
|
||||
- **监控指标**: 请求指标收集,用于监控和调试
|
||||
- **设备指纹**: 设备指纹生成,增强安全性
|
||||
- **冷却管理**: 智能冷却机制,应对 API 速率限制
|
||||
- **用量检查器**: 实时用量监控和配额管理
|
||||
- **模型转换器**: 跨供应商的统一模型名称转换
|
||||
- **UTF-8 流处理**: 改进的流式响应处理
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
||||
|
||||
## Kiro 认证
|
||||
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
### 命令行登录
|
||||
|
||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
||||
|
||||
**AWS Builder ID**(推荐):
|
||||
|
||||
```bash
|
||||
# 设备码流程
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# 授权码流程
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**从 Kiro IDE 导入令牌:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
获取令牌步骤:
|
||||
|
||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# 指定区域
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**附加参数:**
|
||||
|
||||
| 参数 | 说明 |
|
||||
|------|------|
|
||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
||||
|
||||
### 网页端 OAuth 登录
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -70,8 +72,10 @@ func main() {
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var codexDeviceLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
@@ -83,19 +87,27 @@ func main() {
|
||||
var kiroAWSLogin bool
|
||||
var kiroAWSAuthCode bool
|
||||
var kiroImport bool
|
||||
var kiroIDCLogin bool
|
||||
var kiroIDCStartURL string
|
||||
var kiroIDCRegion string
|
||||
var kiroIDCFlow string
|
||||
var githubCopilotLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
var tuiMode bool
|
||||
var standalone bool
|
||||
var noIncognito bool
|
||||
var useIncognito bool
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
@@ -109,11 +121,17 @@ func main() {
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
||||
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||
flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)")
|
||||
flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)")
|
||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -494,11 +512,16 @@ func main() {
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
} else if codexDeviceLogin {
|
||||
// Handle Codex device-code login
|
||||
cmd.DoCodexDeviceLogin(cfg, options)
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else if kiloLogin {
|
||||
cmd.DoKiloLogin(cfg, options)
|
||||
} else if iflowLogin {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
@@ -511,24 +534,34 @@ func main() {
|
||||
// Note: This config mutation is safe - auth commands exit after completion
|
||||
// and don't share config with StartService (which is in the else branch)
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroLogin(cfg, options)
|
||||
} else if kiroGoogleLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
// Note: This config mutation is safe - auth commands exit after completion
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroGoogleLogin(cfg, options)
|
||||
} else if kiroAWSLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroAWSLogin(cfg, options)
|
||||
} else if kiroAWSAuthCode {
|
||||
// For Kiro auth with authorization code flow (better UX)
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
||||
} else if kiroImport {
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroImport(cfg, options)
|
||||
} else if kiroIDCLogin {
|
||||
// For Kiro IDC auth, default to incognito mode for multi-account support
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow)
|
||||
} else {
|
||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||
if isCloudDeploy && !configFileExists {
|
||||
@@ -536,15 +569,89 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
|
||||
// 初始化并启动 Kiro token 后台刷新
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
origLogOutput := log.StandardLogger().Out
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
devNull, errOpenDevNull := os.Open(os.DevNull)
|
||||
if errOpenDevNull == nil {
|
||||
os.Stdout = devNull
|
||||
os.Stderr = devNull
|
||||
}
|
||||
|
||||
restoreIO := func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
log.SetOutput(origLogOutput)
|
||||
if devNull != nil {
|
||||
_ = devNull.Close()
|
||||
}
|
||||
}
|
||||
|
||||
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||
if password == "" {
|
||||
password = localMgmtPassword
|
||||
}
|
||||
|
||||
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
|
||||
|
||||
client := tui.NewClient(cfg.Port, password)
|
||||
ready := false
|
||||
backoff := 100 * time.Millisecond
|
||||
for i := 0; i < 30; i++ {
|
||||
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
|
||||
ready = true
|
||||
break
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
if backoff < time.Second {
|
||||
backoff = time.Duration(float64(backoff) * 1.5)
|
||||
}
|
||||
}
|
||||
|
||||
if !ready {
|
||||
restoreIO()
|
||||
cancel()
|
||||
<-done
|
||||
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
|
||||
return
|
||||
}
|
||||
|
||||
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
|
||||
restoreIO()
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
} else {
|
||||
restoreIO()
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
} else {
|
||||
// Default TUI mode: pure management client.
|
||||
// The proxy server must already be running.
|
||||
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
||||
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
||||
host: ""
|
||||
host: ''
|
||||
|
||||
# Server port
|
||||
port: 8317
|
||||
@@ -8,8 +8,8 @@ port: 8317
|
||||
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
||||
tls:
|
||||
enable: false
|
||||
cert: ""
|
||||
key: ""
|
||||
cert: ''
|
||||
key: ''
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
@@ -20,22 +20,22 @@ remote-management:
|
||||
# Management key. If a plaintext value is provided here, it will be hashed on startup.
|
||||
# All management requests (even from localhost) require this key.
|
||||
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
|
||||
secret-key: ""
|
||||
secret-key: ''
|
||||
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
auth-dir: '~/.cli-proxy-api'
|
||||
|
||||
# API keys for authentication
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- "your-api-key-3"
|
||||
- 'your-api-key-1'
|
||||
- 'your-api-key-2'
|
||||
- 'your-api-key-3'
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
@@ -43,7 +43,7 @@ debug: false
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: "127.0.0.1:8316"
|
||||
addr: '127.0.0.1:8316'
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
@@ -68,14 +68,22 @@ error-logs-max-files: 10
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
proxy-url: ''
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# When true, forward filtered upstream response headers to downstream clients.
|
||||
# Default is false (disabled).
|
||||
passthrough-headers: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Maximum number of different credentials to try for one failed request.
|
||||
# Set to 0 to keep legacy behavior (try all available credentials).
|
||||
max-retry-credentials: 0
|
||||
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
@@ -86,7 +94,7 @@ quota-exceeded:
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: "round-robin" # round-robin (default), fill-first
|
||||
strategy: 'round-robin' # round-robin (default), fill-first
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
@@ -160,17 +168,43 @@ nonstream-keepalive-interval: 0
|
||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
#kiro:
|
||||
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
||||
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
||||
# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login)
|
||||
# region: "us-east-1" # optional: OIDC region for IDC login and token refresh
|
||||
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
||||
# refresh-token: "aorAAAAA..."
|
||||
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
||||
|
||||
# Kilocode (OAuth-based code assistant)
|
||||
# Note: Kilocode uses OAuth device flow authentication.
|
||||
# Use the CLI command: ./server --kilo-login
|
||||
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
|
||||
# oauth-model-alias:
|
||||
# kilo:
|
||||
# - name: "minimax/minimax-m2.5:free"
|
||||
# alias: "minimax-m2.5"
|
||||
# - name: "z-ai/glm-5:free"
|
||||
# alias: "glm-5"
|
||||
# oauth-excluded-models:
|
||||
# kilo:
|
||||
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
|
||||
# - "*:free" # wildcard matching suffix (e.g. all free models)
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
|
||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
|
||||
23
go.mod
23
go.mod
@@ -1,9 +1,13 @@
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24.0
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.0
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
@@ -33,8 +37,16 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
@@ -42,6 +54,7 @@ require (
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||
@@ -58,19 +71,27 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.4.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -189,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
|
||||
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
if useCBORPayload {
|
||||
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
|
||||
if errEncode != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
|
||||
return
|
||||
}
|
||||
requestBody = bytes.NewReader(cborPayload)
|
||||
} else {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
@@ -234,10 +247,18 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// For CBOR upstream responses, decode into plain text or JSON string before returning.
|
||||
responseBodyText := string(respBody)
|
||||
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
|
||||
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
|
||||
responseBodyText = decodedBody
|
||||
}
|
||||
}
|
||||
|
||||
response := apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
Body: responseBodyText,
|
||||
}
|
||||
|
||||
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
|
||||
@@ -747,6 +768,83 @@ func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
||||
if len(headers) == 0 {
|
||||
return false
|
||||
}
|
||||
for key, value := range headers {
|
||||
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
|
||||
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
|
||||
var payload any
|
||||
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
|
||||
return nil, errUnmarshal
|
||||
}
|
||||
return cbor.Marshal(payload)
|
||||
}
|
||||
|
||||
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
|
||||
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var payload any
|
||||
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
jsonCompatible := cborValueToJSONCompatible(payload)
|
||||
switch typed := jsonCompatible.(type) {
|
||||
case string:
|
||||
return typed, nil
|
||||
case []byte:
|
||||
return string(typed), nil
|
||||
default:
|
||||
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
}
|
||||
|
||||
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
|
||||
func cborValueToJSONCompatible(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[any]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[key] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return typed
|
||||
}
|
||||
}
|
||||
|
||||
// QuotaDetail represents quota information for a specific resource type
|
||||
type QuotaDetail struct {
|
||||
Entitlement float64 `json:"entitlement"`
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -29,6 +30,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
@@ -47,14 +49,11 @@ import (
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
anthropicCallbackPort = 54545
|
||||
geminiCallbackPort = 8085
|
||||
codexCallbackPort = 1455
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
geminiCLIApiClient = "gl-node/22.17.0"
|
||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
anthropicCallbackPort = 54545
|
||||
geminiCallbackPort = 8085
|
||||
codexCallbackPort = 1455
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
)
|
||||
|
||||
type callbackForwarder struct {
|
||||
@@ -194,17 +193,6 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func stopCallbackForwarder(port int) {
|
||||
callbackForwardersMu.Lock()
|
||||
forwarder := callbackForwarders[port]
|
||||
if forwarder != nil {
|
||||
delete(callbackForwarders, port)
|
||||
}
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil {
|
||||
return
|
||||
@@ -411,6 +399,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
if !auth.LastRefreshedAt.IsZero() {
|
||||
entry["last_refresh"] = auth.LastRefreshedAt
|
||||
}
|
||||
if !auth.NextRetryAfter.IsZero() {
|
||||
entry["next_retry_after"] = auth.NextRetryAfter
|
||||
}
|
||||
if path != "" {
|
||||
entry["path"] = path
|
||||
entry["source"] = "file"
|
||||
@@ -643,44 +634,85 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
full := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
if !filepath.IsAbs(full) {
|
||||
if abs, errAbs := filepath.Abs(full); errAbs == nil {
|
||||
full = abs
|
||||
|
||||
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
||||
targetID := ""
|
||||
if targetAuth := h.findAuthForDelete(name); targetAuth != nil {
|
||||
targetID = strings.TrimSpace(targetAuth.ID)
|
||||
if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" {
|
||||
targetPath = path
|
||||
}
|
||||
}
|
||||
if err := os.Remove(full); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if !filepath.IsAbs(targetPath) {
|
||||
if abs, errAbs := filepath.Abs(targetPath); errAbs == nil {
|
||||
targetPath = abs
|
||||
}
|
||||
}
|
||||
if errRemove := os.Remove(targetPath); errRemove != nil {
|
||||
if os.IsNotExist(errRemove) {
|
||||
c.JSON(404, gin.H{"error": "file not found"})
|
||||
} else {
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)})
|
||||
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.deleteTokenRecord(ctx, full); err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
|
||||
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
|
||||
return
|
||||
}
|
||||
h.disableAuth(ctx, full)
|
||||
if targetID != "" {
|
||||
h.disableAuth(ctx, targetID)
|
||||
} else {
|
||||
h.disableAuth(ctx, targetPath)
|
||||
}
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
|
||||
if h == nil || h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
return auth
|
||||
}
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(auth.FileName) == name {
|
||||
return auth
|
||||
}
|
||||
if filepath.Base(strings.TrimSpace(authAttribute(auth, "path"))) == name {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) authIDForPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return path
|
||||
id := path
|
||||
if h != nil && h.cfg != nil {
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if authDir != "" {
|
||||
if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
}
|
||||
}
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if authDir == "" {
|
||||
return path
|
||||
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
|
||||
if runtime.GOOS == "windows" {
|
||||
id = strings.ToLower(id)
|
||||
}
|
||||
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
|
||||
return rel
|
||||
}
|
||||
return path
|
||||
return id
|
||||
}
|
||||
|
||||
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
||||
@@ -813,14 +845,104 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find auth by name or ID
|
||||
var targetAuth *coreauth.Auth
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
targetAuth = auth
|
||||
} else {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name {
|
||||
targetAuth = auth
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetAuth == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Prefix != nil {
|
||||
targetAuth.Prefix = *req.Prefix
|
||||
changed = true
|
||||
}
|
||||
if req.ProxyURL != nil {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
|
||||
return
|
||||
}
|
||||
|
||||
targetAuth.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
}
|
||||
authID := h.authIDForPath(id)
|
||||
if authID == "" {
|
||||
authID = strings.TrimSpace(id)
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(id); ok {
|
||||
auth.Disabled = true
|
||||
auth.Status = coreauth.StatusDisabled
|
||||
auth.StatusMessage = "removed via management API"
|
||||
auth.UpdatedAt = time.Now()
|
||||
_, _ = h.authManager.Update(ctx, auth)
|
||||
return
|
||||
}
|
||||
authID := h.authIDForPath(id)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
@@ -869,11 +991,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
if h.postAuthHook != nil {
|
||||
if err := h.postAuthHook(ctx, record); err != nil {
|
||||
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||
}
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Claude authentication...")
|
||||
|
||||
@@ -1018,6 +1146,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||
|
||||
@@ -1193,6 +1322,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
ts.Checked = true
|
||||
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
@@ -1252,6 +1405,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Codex authentication...")
|
||||
|
||||
@@ -1397,6 +1551,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Antigravity authentication...")
|
||||
|
||||
@@ -1561,6 +1716,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Qwen authentication...")
|
||||
|
||||
@@ -1616,6 +1772,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
@@ -1692,6 +1849,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing iFlow authentication...")
|
||||
|
||||
@@ -1811,8 +1969,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
||||
|
||||
// Initialize Copilot auth service
|
||||
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
|
||||
// Assuming copilot package is imported as "copilot"
|
||||
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
||||
|
||||
// Initiate device flow
|
||||
@@ -1826,7 +1982,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
authURL := deviceCode.VerificationURI
|
||||
userCode := deviceCode.UserCode
|
||||
|
||||
RegisterOAuthSession(state, "github")
|
||||
RegisterOAuthSession(state, "github-copilot")
|
||||
|
||||
go func() {
|
||||
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
||||
@@ -1838,9 +1994,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
if errUser != nil {
|
||||
log.Warnf("Failed to fetch user info: %v", errUser)
|
||||
}
|
||||
|
||||
username := userInfo.Login
|
||||
if username == "" {
|
||||
username = "github-user"
|
||||
}
|
||||
|
||||
@@ -1849,18 +2009,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
TokenType: tokenData.TokenType,
|
||||
Scope: tokenData.Scope,
|
||||
Username: username,
|
||||
Email: userInfo.Email,
|
||||
Name: userInfo.Name,
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("github-%s.json", username)
|
||||
fileName := fmt.Sprintf("github-copilot-%s.json", username)
|
||||
label := userInfo.Email
|
||||
if label == "" {
|
||||
label = username
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "github",
|
||||
Provider: "github-copilot",
|
||||
Label: label,
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": username,
|
||||
"email": userInfo.Email,
|
||||
"username": username,
|
||||
"name": userInfo.Name,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1874,7 +2042,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("github")
|
||||
CompleteOAuthSessionsByProvider("github-copilot")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
@@ -2124,7 +2292,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -2212,9 +2421,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
||||
return fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
@@ -2284,7 +2491,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -2305,7 +2512,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo = httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -2374,6 +2581,15 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
// PopulateAuthContext extracts request info and adds it to the context
|
||||
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||
info := &coreauth.RequestInfo{
|
||||
Query: c.Request.URL.Query(),
|
||||
Headers: c.Request.Header,
|
||||
}
|
||||
return coreauth.WithRequestInfo(ctx, info)
|
||||
}
|
||||
|
||||
const kiroCallbackPort = 9876
|
||||
|
||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
@@ -2510,6 +2726,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||
if errTarget != nil {
|
||||
@@ -2517,7 +2734,8 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -2526,7 +2744,7 @@ func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(kiroCallbackPort)
|
||||
defer stopCallbackForwarderInstance(kiroCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||
@@ -2668,3 +2886,88 @@ func generateKiroPKCE() (verifier, challenge string, err error) {
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKiloToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kilo authentication...")
|
||||
|
||||
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
|
||||
kilocodeAuth := kilo.NewKiloAuth()
|
||||
|
||||
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initiate device flow: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kilo")
|
||||
|
||||
go func() {
|
||||
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
|
||||
|
||||
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
|
||||
if err != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to fetch profile: %v", err)
|
||||
profile = &kilo.Profile{Email: status.UserEmail}
|
||||
}
|
||||
|
||||
var orgID string
|
||||
if len(profile.Orgs) > 0 {
|
||||
orgID = profile.Orgs[0].ID
|
||||
}
|
||||
|
||||
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
|
||||
if err != nil {
|
||||
defaults = &kilo.Defaults{}
|
||||
}
|
||||
|
||||
ts := &kilo.KiloTokenStorage{
|
||||
Token: status.Token,
|
||||
OrganizationID: orgID,
|
||||
Model: defaults.Model,
|
||||
Email: status.UserEmail,
|
||||
Type: "kilo",
|
||||
}
|
||||
|
||||
fileName := kilo.CredentialFileName(status.UserEmail)
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kilo",
|
||||
FileName: fileName,
|
||||
Storage: ts,
|
||||
Metadata: map[string]any{
|
||||
"email": status.UserEmail,
|
||||
"organization_id": orgID,
|
||||
"model": defaults.Model,
|
||||
},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kilo")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"url": resp.VerificationURL,
|
||||
"state": state,
|
||||
"user_code": resp.Code,
|
||||
"verification_uri": resp.VerificationURL,
|
||||
})
|
||||
}
|
||||
|
||||
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
authDir := filepath.Join(tempDir, "auth")
|
||||
externalDir := filepath.Join(tempDir, "external")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
if errMkdirExternal := os.MkdirAll(externalDir, 0o700); errMkdirExternal != nil {
|
||||
t.Fatalf("failed to create external dir: %v", errMkdirExternal)
|
||||
}
|
||||
|
||||
fileName := "codex-user@example.com-plus.json"
|
||||
shadowPath := filepath.Join(authDir, fileName)
|
||||
realPath := filepath.Join(externalDir, fileName)
|
||||
if errWriteShadow := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); errWriteShadow != nil {
|
||||
t.Fatalf("failed to write shadow file: %v", errWriteShadow)
|
||||
}
|
||||
if errWriteReal := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); errWriteReal != nil {
|
||||
t.Fatalf("failed to write real file: %v", errWriteReal)
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "legacy/" + fileName,
|
||||
FileName: fileName,
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusError,
|
||||
Unavailable: true,
|
||||
Attributes: map[string]string{
|
||||
"path": realPath,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "codex",
|
||||
"email": "real@example.com",
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||
deleteCtx.Request = deleteReq
|
||||
h.DeleteAuthFile(deleteCtx)
|
||||
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
if _, errStatReal := os.Stat(realPath); !os.IsNotExist(errStatReal) {
|
||||
t.Fatalf("expected managed auth file to be removed, stat err: %v", errStatReal)
|
||||
}
|
||||
if _, errStatShadow := os.Stat(shadowPath); errStatShadow != nil {
|
||||
t.Fatalf("expected shadow auth file to remain, stat err: %v", errStatShadow)
|
||||
}
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listCtx, _ := gin.CreateTestContext(listRec)
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||
listCtx.Request = listReq
|
||||
h.ListAuthFiles(listCtx)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String())
|
||||
}
|
||||
var listPayload map[string]any
|
||||
if errUnmarshal := json.Unmarshal(listRec.Body.Bytes(), &listPayload); errUnmarshal != nil {
|
||||
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||
}
|
||||
filesRaw, ok := listPayload["files"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected files array, payload: %#v", listPayload)
|
||||
}
|
||||
if len(filesRaw) != 0 {
|
||||
t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(filesRaw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
fileName := "fallback-user.json"
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex"}`), 0o600); errWrite != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||
deleteCtx.Request = deleteReq
|
||||
h.DeleteAuthFile(deleteCtx)
|
||||
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) {
|
||||
t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", errStat)
|
||||
}
|
||||
}
|
||||
@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
c.JSON(200, &cfgCopy)
|
||||
c.JSON(200, new(*h.cfg))
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
|
||||
@@ -796,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelAlias, channel)
|
||||
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||
h.cfg.OAuthModelAlias = nil
|
||||
}
|
||||
// Set to nil instead of deleting the key so that the "explicitly disabled"
|
||||
// marker survives config reload and prevents SanitizeOAuthModelAlias from
|
||||
// re-injecting default aliases (fixes #222).
|
||||
h.cfg.OAuthModelAlias[channel] = nil
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ type Handler struct {
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
h.postAuthHook = hook
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
|
||||
@@ -15,10 +15,12 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
loggerEnabled := logger.IsEnabled()
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
if !loggerEnabled {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||
if req == nil {
|
||||
return true
|
||||
}
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
return !isResponsesWebsocketUpgrade(req)
|
||||
}
|
||||
|
||||
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil {
|
||||
return false
|
||||
}
|
||||
if req.URL.Path != "/v1/responses" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||
}
|
||||
|
||||
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||
if loggerEnabled {
|
||||
return true
|
||||
}
|
||||
if req == nil || req.Body == nil {
|
||||
return false
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return false
|
||||
}
|
||||
if req.ContentLength <= 0 {
|
||||
return false
|
||||
}
|
||||
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
if captureBody && c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "post request should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "plain get should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/models"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "responses websocket upgrade should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "responses get without upgrade should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||
if got != tests[i].skip {
|
||||
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loggerEnabled bool
|
||||
req *http.Request
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "logger enabled always captures",
|
||||
loggerEnabled: true,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil request",
|
||||
loggerEnabled: false,
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "small known size json in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: 2,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large known size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multipart skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: 1,
|
||||
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||
if got != tests[i].want {
|
||||
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
|
||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||
}
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "original-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||
}
|
||||
|
||||
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||
body = wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-as-string" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
m.lastConfig = &settingsCopy
|
||||
m.lastConfig = new(settings)
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -77,6 +78,9 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
req.Header.Del("X-Api-Key")
|
||||
req.Header.Del("X-Goog-Api-Key")
|
||||
|
||||
// Remove proxy, client identity, and browser fingerprint headers
|
||||
misc.ScrubProxyAndFingerprintHeaders(req)
|
||||
|
||||
// Remove query-based credentials if they match the authenticated client API key.
|
||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||
// breaking unrelated upstream query parameters.
|
||||
@@ -215,7 +219,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
|
||||
// Don't log as error for context canceled - it's usually client closing connection
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||
return
|
||||
} else {
|
||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||
}
|
||||
|
||||
@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
|
||||
// Test that context.Canceled errors return 499 without generic error response
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a canceled context to trigger the cancellation path
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Directly invoke the ErrorHandler with context.Canceled
|
||||
proxy.ErrorHandler(rr, req, context.Canceled)
|
||||
|
||||
// Body should be empty for canceled requests (no JSON error response)
|
||||
body := rr.Body.Bytes()
|
||||
if len(body) > 0 {
|
||||
t.Fatalf("expected empty body for canceled context, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||
// Upstream returns gzipped JSON without Content-Encoding header
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -52,6 +52,7 @@ type serverOptionConfig struct {
|
||||
keepAliveEnabled bool
|
||||
keepAliveTimeout time.Duration
|
||||
keepAliveOnTimeout func()
|
||||
postAuthHook auth.PostAuthHook
|
||||
}
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
@@ -59,10 +60,8 @@ type ServerOption func(*serverOptionConfig)
|
||||
|
||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||
configDir := filepath.Dir(configPath)
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
|
||||
logsDir := logging.ResolveLogDirectory(cfg)
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, logsDir, configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
@@ -112,6 +111,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||
return func(cfg *serverOptionConfig) {
|
||||
cfg.postAuthHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
@@ -252,7 +258,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
if authManager != nil {
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
@@ -263,6 +269,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
if optionState.postAuthHook != nil {
|
||||
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||
}
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
// Setup routes
|
||||
@@ -285,8 +294,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||
}
|
||||
|
||||
// Register management routes when configuration or environment secrets are available.
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
||||
// Register management routes when configuration or environment secrets are available,
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
@@ -329,6 +339,7 @@ func (s *Server) setupRoutes() {
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||
}
|
||||
@@ -642,6 +653,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
@@ -649,6 +661,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
@@ -931,7 +944,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
|
||||
// Update log level dynamically when debug flag changes
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
@@ -109,3 +111,100 @@ func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
t.Setenv("WRITABLE_PATH", "")
|
||||
t.Setenv("writable_path", "")
|
||||
|
||||
originalWD, errGetwd := os.Getwd()
|
||||
if errGetwd != nil {
|
||||
t.Fatalf("failed to get current working directory: %v", errGetwd)
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
if errChdir := os.Chdir(tmpDir); errChdir != nil {
|
||||
t.Fatalf("failed to switch working directory: %v", errChdir)
|
||||
}
|
||||
defer func() {
|
||||
if errChdirBack := os.Chdir(originalWD); errChdirBack != nil {
|
||||
t.Fatalf("failed to restore working directory: %v", errChdirBack)
|
||||
}
|
||||
}()
|
||||
|
||||
// Force ResolveLogDirectory to fallback to auth-dir/logs by making ./logs not a writable directory.
|
||||
if errWriteFile := os.WriteFile(filepath.Join(tmpDir, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
|
||||
t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(tmpDir, "config")
|
||||
if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
|
||||
t.Fatalf("failed to create config dir: %v", errMkdirConfig)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "config.yaml")
|
||||
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
|
||||
cfg := &proxyconfig.Config{
|
||||
SDKConfig: proxyconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
},
|
||||
AuthDir: authDir,
|
||||
ErrorLogsMaxFiles: 10,
|
||||
}
|
||||
|
||||
logger := defaultRequestLoggerFactory(cfg, configPath)
|
||||
fileLogger, ok := logger.(*internallogging.FileRequestLogger)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FileRequestLogger, got %T", logger)
|
||||
}
|
||||
|
||||
errLog := fileLogger.LogRequestWithOptions(
|
||||
"/v1/chat/completions",
|
||||
http.MethodPost,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"input":"hello"}`),
|
||||
http.StatusBadGateway,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"error":"upstream failure"}`),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
"issue-1711",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
)
|
||||
if errLog != nil {
|
||||
t.Fatalf("failed to write forced error request log: %v", errLog)
|
||||
}
|
||||
|
||||
authLogsDir := filepath.Join(authDir, "logs")
|
||||
authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
|
||||
if errReadAuthDir != nil {
|
||||
t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
|
||||
}
|
||||
foundErrorLogInAuthDir := false
|
||||
for _, entry := range authEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
foundErrorLogInAuthDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundErrorLogInAuthDir {
|
||||
t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
|
||||
}
|
||||
|
||||
configLogsDir := filepath.Join(configDir, "logs")
|
||||
configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
|
||||
if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
|
||||
t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
|
||||
}
|
||||
for _, entry := range configEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
// OAuth configuration constants for Claude/Anthropic
|
||||
const (
|
||||
AuthURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
RedirectURI = "http://localhost:54545/callback"
|
||||
)
|
||||
|
||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
type utlsRoundTripper struct {
|
||||
// mu protects the connections map and pending map
|
||||
@@ -100,7 +100,9 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
|
||||
// createConnection creates a new HTTP/2 connection with Chrome TLS fingerprint.
|
||||
// Chrome's TLS fingerprint is closer to Node.js/OpenSSL (which real Claude Code uses)
|
||||
// than Firefox, reducing the mismatch between TLS layer and HTTP headers.
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := t.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
@@ -108,7 +110,7 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
@@ -156,7 +158,7 @@ func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
||||
// for Anthropic domains by using utls with Firefox fingerprint.
|
||||
// for Anthropic domains by using utls with Chrome fingerprint.
|
||||
// It accepts optional SDK configuration for proxy settings.
|
||||
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
||||
return &http.Client{
|
||||
|
||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||
// login while preserving the existing token parsing and storage behavior.
|
||||
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
if isNonRetryableRefreshErr(err) {
|
||||
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableRefreshErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
raw := strings.ToLower(err.Error())
|
||||
return strings.Contains(raw, "refresh_token_reused")
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||
// This is typically called after a successful token refresh to persist the new credentials.
|
||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||
|
||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||
var calls int32
|
||||
auth := &CodexAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for non-retryable refresh failure")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
|
||||
}
|
||||
|
||||
// Fetch the GitHub username
|
||||
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
if err != nil {
|
||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||
username = "unknown"
|
||||
}
|
||||
|
||||
username := userInfo.Login
|
||||
if username == "" {
|
||||
username = "github-user"
|
||||
}
|
||||
|
||||
return &CopilotAuthBundle{
|
||||
TokenData: tokenData,
|
||||
Username: username,
|
||||
Email: userInfo.Email,
|
||||
Name: userInfo.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return true, username, nil
|
||||
return true, userInfo.Login, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
Username: bundle.Username,
|
||||
Email: bundle.Email,
|
||||
Name: bundle.Name,
|
||||
Type: "github-copilot",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", copilotClientID)
|
||||
data.Set("scope", "user:email")
|
||||
data.Set("scope", "read:user user:email")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||
// GitHubUserInfo holds GitHub user profile information.
|
||||
type GitHubUserInfo struct {
|
||||
// Login is the GitHub username.
|
||||
Login string
|
||||
// Email is the primary email address (may be empty if not public).
|
||||
Email string
|
||||
// Name is the display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||
if accessToken == "" {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||
if err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
}
|
||||
|
||||
var userInfo struct {
|
||||
var raw struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
|
||||
if userInfo.Login == "" {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
if raw.Login == "" {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
}
|
||||
|
||||
return userInfo.Login, nil
|
||||
return GitHubUserInfo{
|
||||
Login: raw.Login,
|
||||
Email: raw.Email,
|
||||
Name: raw.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// roundTripFunc lets us inject a custom transport for testing.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||
// regardless of the original URL host.
|
||||
func newTestClient(srv *httptest.Server) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req2 := req.Clone(req.Context())
|
||||
req2.URL.Scheme = "http"
|
||||
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||
return srv.Client().Transport.RoundTrip(req2)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"login": "octocat",
|
||||
"email": "octocat@github.com",
|
||||
"name": "The Octocat",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "octocat" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||
}
|
||||
if info.Email != "octocat@github.com" {
|
||||
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||
}
|
||||
if info.Name != "The Octocat" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// GitHub returns null for private emails.
|
||||
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "privateuser" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||
}
|
||||
if info.Email != "" {
|
||||
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||
}
|
||||
if info.Name != "Private User" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||
_, err := client.FetchUserInfo(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty login, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
TokenType: "bearer",
|
||||
Scope: "read:user user:email",
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||
if _, ok := out[key]; !ok {
|
||||
t.Errorf("expected key %q in JSON output, not found", key)
|
||||
}
|
||||
}
|
||||
if out["email"] != "octocat@github.com" {
|
||||
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||
}
|
||||
if out["name"] != "The Octocat" {
|
||||
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
Username: "octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := out["email"]; ok {
|
||||
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
if _, ok := out["name"]; ok {
|
||||
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||
bundle := &CopilotAuthBundle{
|
||||
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if bundle.Email != "octocat@github.com" {
|
||||
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||
}
|
||||
if bundle.Name != "The Octocat" {
|
||||
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||
info := GitHubUserInfo{
|
||||
Login: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||
t.Error("GitHubUserInfo fields should not be empty")
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
// Username is the GitHub username associated with this token.
|
||||
Username string `json:"username"`
|
||||
// Email is the GitHub email address associated with this token.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Name is the GitHub display name associated with this token.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
|
||||
TokenData *CopilotTokenData
|
||||
// Username is the GitHub username.
|
||||
Username string
|
||||
// Email is the GitHub email address.
|
||||
Email string
|
||||
// Name is the GitHub display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents GitHub's device code response.
|
||||
|
||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
168
internal/auth/kilo/kilo_auth.go
Normal file
168
internal/auth/kilo/kilo_auth.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// BaseURL is the base URL for the Kilo AI API.
|
||||
BaseURL = "https://api.kilo.ai/api"
|
||||
)
|
||||
|
||||
// DeviceAuthResponse represents the response from initiating device flow.
|
||||
type DeviceAuthResponse struct {
|
||||
Code string `json:"code"`
|
||||
VerificationURL string `json:"verificationUrl"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// DeviceStatusResponse represents the response when polling for device flow status.
|
||||
type DeviceStatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
}
|
||||
|
||||
// Profile represents the user profile from Kilo AI.
|
||||
type Profile struct {
|
||||
Email string `json:"email"`
|
||||
Orgs []Organization `json:"organizations"`
|
||||
}
|
||||
|
||||
// Organization represents a Kilo AI organization.
|
||||
type Organization struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Defaults represents default settings for an organization or user.
|
||||
type Defaults struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// KiloAuth provides methods for handling the Kilo AI authentication flow.
|
||||
type KiloAuth struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewKiloAuth creates a new instance of KiloAuth.
|
||||
func NewKiloAuth() *KiloAuth {
|
||||
return &KiloAuth{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the device authentication flow.
|
||||
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
|
||||
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data DeviceAuthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// PollForToken polls for the device flow completion.
|
||||
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data DeviceStatusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch data.Status {
|
||||
case "approved":
|
||||
return &data, nil
|
||||
case "denied", "expired":
|
||||
return nil, fmt.Errorf("device flow %s", data.Status)
|
||||
case "pending":
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown status: %s", data.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile fetches the user's profile.
|
||||
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get profile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var profile Profile
|
||||
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// GetDefaults fetches default settings for an organization.
|
||||
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
|
||||
url := BaseURL + "/defaults"
|
||||
if orgID != "" {
|
||||
url = BaseURL + "/organizations/" + orgID + "/defaults"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var defaults Defaults
|
||||
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &defaults, nil
|
||||
}
|
||||
60
internal/auth/kilo/kilo_token.go
Normal file
60
internal/auth/kilo/kilo_token.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KiloTokenStorage stores token information for Kilo AI authentication.
|
||||
type KiloTokenStorage struct {
|
||||
// Token is the Kilo access token.
|
||||
Token string `json:"kilocodeToken"`
|
||||
|
||||
// OrganizationID is the Kilo organization ID.
|
||||
OrganizationID string `json:"kilocodeOrganizationId"`
|
||||
|
||||
// Model is the default model to use.
|
||||
Model string `json:"kilocodeModel"`
|
||||
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Type indicates the authentication provider type, always "kilo" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
|
||||
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kilo"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Kilo credentials.
|
||||
func CredentialFileName(email string) string {
|
||||
return fmt.Sprintf("kilo-%s.json", email)
|
||||
}
|
||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -7,10 +7,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
@@ -47,7 +50,7 @@ type KiroTokenData struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||
// Region is the OIDC region for IDC login and token refresh
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
@@ -520,3 +523,159 @@ func GenerateTokenFileName(tokenData *KiroTokenData) string {
|
||||
// Priority 3: Fallback to authMethod only with sequence
|
||||
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
|
||||
}
|
||||
|
||||
// DefaultKiroRegion is the fallback region when none is specified.
|
||||
const DefaultKiroRegion = "us-east-1"
|
||||
|
||||
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
|
||||
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
|
||||
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
|
||||
func GetCodeWhispererLegacyEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = DefaultKiroRegion
|
||||
}
|
||||
return "https://codewhisperer." + region + ".amazonaws.com"
|
||||
}
|
||||
|
||||
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
|
||||
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
|
||||
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
|
||||
type ProfileARN struct {
|
||||
// Raw is the original ARN string
|
||||
Raw string
|
||||
// Partition is the AWS partition (aws)
|
||||
Partition string
|
||||
// Service is the AWS service name (codewhisperer)
|
||||
Service string
|
||||
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
|
||||
Region string
|
||||
// AccountID is the AWS account ID
|
||||
AccountID string
|
||||
// ResourceType is the resource type (profile)
|
||||
ResourceType string
|
||||
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
|
||||
ResourceID string
|
||||
}
|
||||
|
||||
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
|
||||
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
|
||||
func ParseProfileARN(arn string) *ProfileARN {
|
||||
if arn == "" {
|
||||
return nil
|
||||
}
|
||||
// ARN format: arn:partition:service:region:account-id:resource
|
||||
// Minimum 6 parts separated by ":"
|
||||
parts := strings.Split(arn, ":")
|
||||
if len(parts) < 6 {
|
||||
log.Warnf("invalid ARN format: %s", arn)
|
||||
return nil
|
||||
}
|
||||
// Validate ARN prefix
|
||||
if parts[0] != "arn" {
|
||||
return nil
|
||||
}
|
||||
// Validate partition
|
||||
partition := parts[1]
|
||||
if partition == "" {
|
||||
return nil
|
||||
}
|
||||
// Validate service is codewhisperer
|
||||
service := parts[2]
|
||||
if service != "codewhisperer" {
|
||||
return nil
|
||||
}
|
||||
// Validate region format (must contain "-")
|
||||
region := parts[3]
|
||||
if region == "" || !strings.Contains(region, "-") {
|
||||
return nil
|
||||
}
|
||||
// Account ID
|
||||
accountID := parts[4]
|
||||
|
||||
// Parse resource (format: resource-type/resource-id)
|
||||
// Join remaining parts in case resource contains ":"
|
||||
resource := strings.Join(parts[5:], ":")
|
||||
resourceType := ""
|
||||
resourceID := ""
|
||||
if idx := strings.Index(resource, "/"); idx > 0 {
|
||||
resourceType = resource[:idx]
|
||||
resourceID = resource[idx+1:]
|
||||
} else {
|
||||
resourceType = resource
|
||||
}
|
||||
|
||||
return &ProfileARN{
|
||||
Raw: arn,
|
||||
Partition: partition,
|
||||
Service: service,
|
||||
Region: region,
|
||||
AccountID: accountID,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: resourceID,
|
||||
}
|
||||
}
|
||||
|
||||
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
|
||||
// If region is empty, defaults to us-east-1.
|
||||
func GetKiroAPIEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = DefaultKiroRegion
|
||||
}
|
||||
return "https://q." + region + ".amazonaws.com"
|
||||
}
|
||||
|
||||
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
|
||||
// Returns default us-east-1 endpoint if region cannot be extracted.
|
||||
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
|
||||
region := ExtractRegionFromProfileArn(profileArn)
|
||||
return GetKiroAPIEndpoint(region)
|
||||
}
|
||||
|
||||
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
|
||||
// Returns empty string if ARN is invalid or region cannot be extracted.
|
||||
func ExtractRegionFromProfileArn(profileArn string) string {
|
||||
parsed := ParseProfileARN(profileArn)
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Region
|
||||
}
|
||||
|
||||
// ExtractRegionFromMetadata extracts API region from auth metadata.
|
||||
// Priority: api_region > profile_arn > DefaultKiroRegion
|
||||
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
|
||||
if metadata == nil {
|
||||
return DefaultKiroRegion
|
||||
}
|
||||
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := metadata["api_region"].(string); ok && r != "" {
|
||||
return r
|
||||
}
|
||||
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
|
||||
return region
|
||||
}
|
||||
}
|
||||
|
||||
return DefaultKiroRegion
|
||||
}
|
||||
|
||||
func buildURL(endpoint, path string, queryParams map[string]string) string {
|
||||
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
|
||||
if len(queryParams) > 0 {
|
||||
values := url.Values{}
|
||||
for key, value := range queryParams {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
values.Set(key, value)
|
||||
}
|
||||
if encoded := values.Encode(); encoded != "" {
|
||||
fullURL = fullURL + "?" + encoded
|
||||
}
|
||||
}
|
||||
return fullURL
|
||||
}
|
||||
|
||||
@@ -19,15 +19,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
|
||||
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
|
||||
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
|
||||
// for their respective API operations.
|
||||
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
|
||||
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
|
||||
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
|
||||
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
|
||||
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
|
||||
pathGetUsageLimits = "getUsageLimits"
|
||||
pathListAvailableModels = "ListAvailableModels"
|
||||
)
|
||||
|
||||
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
|
||||
@@ -35,7 +28,6 @@ const (
|
||||
// and communicating with the CodeWhisperer API.
|
||||
type KiroAuth struct {
|
||||
httpClient *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// NewKiroAuth creates a new Kiro authentication service.
|
||||
@@ -49,7 +41,6 @@ type KiroAuth struct {
|
||||
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
||||
return &KiroAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,33 +101,30 @@ func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
|
||||
return time.Now().After(expiresAt)
|
||||
}
|
||||
|
||||
// makeRequest sends a request to the CodeWhisperer API.
|
||||
// This is an internal method for making authenticated API calls.
|
||||
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
|
||||
// - accessToken: The OAuth access token
|
||||
// - payload: The request payload
|
||||
// - path: The API path (e.g., "getUsageLimits")
|
||||
// - tokenData: The token data containing access token, refresh token, and profile ARN
|
||||
// - queryParams: Query parameters to add to the URL
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The response body
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
|
||||
jsonBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
|
||||
// Get endpoint from profileArn (defaults to us-east-1 if empty)
|
||||
profileArn := queryParams["profileArn"]
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||
url := buildURL(endpoint, path, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", target)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||
|
||||
resp, err := k.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -171,13 +159,13 @@ func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken s
|
||||
// - *KiroUsageInfo: The usage information
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
|
||||
payload := map[string]interface{}{
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
|
||||
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -221,12 +209,12 @@ func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData)
|
||||
// - []*KiroModel: The list of available models
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
|
||||
payload := map[string]interface{}{
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
}
|
||||
|
||||
body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
|
||||
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package kiro
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -217,7 +218,8 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenData *KiroTokenData
|
||||
expected string
|
||||
exact string // exact match (for cases with email)
|
||||
prefix string // prefix match (for cases without email, where sequence is appended)
|
||||
}{
|
||||
{
|
||||
name: "IDC with email",
|
||||
@@ -226,7 +228,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "user@example.com",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
expected: "kiro-idc-user-example-com.json",
|
||||
exact: "kiro-idc-user-example-com.json",
|
||||
},
|
||||
{
|
||||
name: "IDC without email but with startUrl",
|
||||
@@ -235,7 +237,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
expected: "kiro-idc-d-1234567890.json",
|
||||
prefix: "kiro-idc-d-1234567890-",
|
||||
},
|
||||
{
|
||||
name: "IDC with company name in startUrl",
|
||||
@@ -244,7 +246,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "",
|
||||
StartURL: "https://my-company.awsapps.com/start",
|
||||
},
|
||||
expected: "kiro-idc-my-company.json",
|
||||
prefix: "kiro-idc-my-company-",
|
||||
},
|
||||
{
|
||||
name: "IDC without email and without startUrl",
|
||||
@@ -253,7 +255,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "",
|
||||
StartURL: "",
|
||||
},
|
||||
expected: "kiro-idc.json",
|
||||
prefix: "kiro-idc-",
|
||||
},
|
||||
{
|
||||
name: "Builder ID with email",
|
||||
@@ -262,7 +264,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "user@gmail.com",
|
||||
StartURL: "https://view.awsapps.com/start",
|
||||
},
|
||||
expected: "kiro-builder-id-user-gmail-com.json",
|
||||
exact: "kiro-builder-id-user-gmail-com.json",
|
||||
},
|
||||
{
|
||||
name: "Builder ID without email",
|
||||
@@ -271,7 +273,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "",
|
||||
StartURL: "https://view.awsapps.com/start",
|
||||
},
|
||||
expected: "kiro-builder-id.json",
|
||||
prefix: "kiro-builder-id-",
|
||||
},
|
||||
{
|
||||
name: "Social auth with email",
|
||||
@@ -279,7 +281,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
AuthMethod: "google",
|
||||
Email: "user@gmail.com",
|
||||
},
|
||||
expected: "kiro-google-user-gmail-com.json",
|
||||
exact: "kiro-google-user-gmail-com.json",
|
||||
},
|
||||
{
|
||||
name: "Empty auth method",
|
||||
@@ -287,7 +289,7 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
AuthMethod: "",
|
||||
Email: "",
|
||||
},
|
||||
expected: "kiro-unknown.json",
|
||||
prefix: "kiro-unknown-",
|
||||
},
|
||||
{
|
||||
name: "Email with special characters",
|
||||
@@ -296,16 +298,454 @@ func TestGenerateTokenFileName(t *testing.T) {
|
||||
Email: "user.name+tag@sub.example.com",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
expected: "kiro-idc-user-name+tag-sub-example-com.json",
|
||||
exact: "kiro-idc-user-name+tag-sub-example-com.json",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateTokenFileName(tt.tokenData)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected)
|
||||
if tt.exact != "" {
|
||||
if result != tt.exact {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
|
||||
}
|
||||
} else if tt.prefix != "" {
|
||||
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseProfileARN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arn string
|
||||
expected *ProfileARN
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN",
|
||||
arn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid format - too few parts",
|
||||
arn: "arn:aws:codewhisperer",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid prefix - not arn",
|
||||
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid service - not codewhisperer",
|
||||
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid region - no hyphen",
|
||||
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty partition",
|
||||
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty region",
|
||||
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ABCDEFGHIJKL",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "ap-southeast-1",
|
||||
AccountID: "987654321098",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ZYXWVUTSRQ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-west-1",
|
||||
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "eu-west-1",
|
||||
AccountID: "111222333444",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "PROFILE123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - aws-cn partition",
|
||||
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||
Partition: "aws-cn",
|
||||
Service: "codewhisperer",
|
||||
Region: "cn-north-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "CHINAID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - resource without slash",
|
||||
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-west-2",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - resource with colon",
|
||||
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ABC:extra",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ParseProfileARN(tt.arn)
|
||||
if tt.expected == nil {
|
||||
if result != nil {
|
||||
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
|
||||
return
|
||||
}
|
||||
if result.Raw != tt.expected.Raw {
|
||||
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
|
||||
}
|
||||
if result.Partition != tt.expected.Partition {
|
||||
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
|
||||
}
|
||||
if result.Service != tt.expected.Service {
|
||||
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
|
||||
}
|
||||
if result.Region != tt.expected.Region {
|
||||
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
|
||||
}
|
||||
if result.AccountID != tt.expected.AccountID {
|
||||
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
|
||||
}
|
||||
if result.ResourceType != tt.expected.ResourceType {
|
||||
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
|
||||
}
|
||||
if result.ResourceID != tt.expected.ResourceID {
|
||||
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRegionFromProfileArn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileArn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN",
|
||||
profileArn: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid ARN",
|
||||
profileArn: "invalid-arn",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
expected: "ap-southeast-1",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-central-1",
|
||||
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
expected: "eu-central-1",
|
||||
},
|
||||
{
|
||||
name: "Non-codewhisperer ARN",
|
||||
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractRegionFromProfileArn(tt.profileArn)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroAPIEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-west-2",
|
||||
region: "us-west-2",
|
||||
expected: "https://q.us-west-2.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "ap-southeast-1",
|
||||
region: "ap-southeast-1",
|
||||
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "eu-west-1",
|
||||
region: "eu-west-1",
|
||||
expected: "https://q.eu-west-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "cn-north-1",
|
||||
region: "cn-north-1",
|
||||
expected: "https://q.cn-north-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetKiroAPIEndpoint(tt.region)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileArn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN - defaults to us-east-1",
|
||||
profileArn: "",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Invalid ARN - defaults to us-east-1",
|
||||
profileArn: "invalid-arn",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-central-1",
|
||||
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
expected: "https://q.eu-central-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-west-2",
|
||||
region: "us-west-2",
|
||||
expected: "https://codewhisperer.us-west-2.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "ap-northeast-1",
|
||||
region: "ap-northeast-1",
|
||||
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetCodeWhispererLegacyEndpoint(tt.region)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRegionFromMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Nil metadata - defaults to us-east-1",
|
||||
metadata: nil,
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Empty metadata - defaults to us-east-1",
|
||||
metadata: map[string]interface{}{},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 1: api_region override",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "eu-west-1",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "eu-west-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 2: profile_arn when api_region is empty",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "",
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "ap-southeast-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 2: profile_arn when api_region is missing",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "eu-central-1",
|
||||
},
|
||||
{
|
||||
name: "Fallback: default when profile_arn is invalid",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "invalid-arn",
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Fallback: default when profile_arn is empty",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "",
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "OIDC region is NOT used for API region",
|
||||
metadata: map[string]interface{}{
|
||||
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "api_region takes precedence over OIDC region",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "us-west-2",
|
||||
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||
},
|
||||
expected: "us-west-2",
|
||||
},
|
||||
{
|
||||
name: "Non-string api_region is ignored",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": 123, // wrong type
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "ap-south-1",
|
||||
},
|
||||
{
|
||||
name: "Non-string profile_arn is ignored",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": 123, // wrong type
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractRegionFromMetadata(tt.metadata)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,30 +9,23 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"
|
||||
kiroVersion = "0.6.18"
|
||||
)
|
||||
|
||||
// CodeWhispererClient handles CodeWhisperer API calls.
|
||||
type CodeWhispererClient struct {
|
||||
httpClient *http.Client
|
||||
machineID string
|
||||
}
|
||||
|
||||
// UsageLimitsResponse represents the getUsageLimits API response.
|
||||
type UsageLimitsResponse struct {
|
||||
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
||||
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo contains user information from the API.
|
||||
@@ -49,13 +42,13 @@ type SubscriptionInfo struct {
|
||||
|
||||
// UsageBreakdown contains usage details.
|
||||
type UsageBreakdown struct {
|
||||
UsageLimit *int `json:"usageLimit,omitempty"`
|
||||
CurrentUsage *int `json:"currentUsage,omitempty"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
UsageLimit *int `json:"usageLimit,omitempty"`
|
||||
CurrentUsage *int `json:"currentUsage,omitempty"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
}
|
||||
|
||||
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
||||
@@ -64,40 +57,34 @@ func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhisperer
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
if machineID == "" {
|
||||
machineID = uuid.New().String()
|
||||
}
|
||||
return &CodeWhispererClient{
|
||||
httpClient: client,
|
||||
machineID: machineID,
|
||||
}
|
||||
}
|
||||
|
||||
// generateInvocationID generates a unique invocation ID.
|
||||
func generateInvocationID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
||||
// This is the recommended way to get user email after login.
|
||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
|
||||
url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
|
||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
// Determine endpoint based on profileArn region
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||
if profileArn != "" {
|
||||
queryParams["profileArn"] = profileArn
|
||||
} else {
|
||||
queryParams["isEmailRequired"] = "true"
|
||||
}
|
||||
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers to match Kiro IDE
|
||||
xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
|
||||
userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("x-amz-user-agent", xAmzUserAgent)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||
req.Header.Set("Connection", "close")
|
||||
accountKey := GetAccountKey(clientID, refreshToken)
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
log.Debugf("codewhisperer: GET %s", url)
|
||||
|
||||
@@ -128,8 +115,8 @@ func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken st
|
||||
|
||||
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
||||
// This is more reliable than JWT parsing as it uses the official API.
|
||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
|
||||
resp, err := c.GetUsageLimits(ctx, accessToken)
|
||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
|
||||
if err != nil {
|
||||
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
||||
return ""
|
||||
@@ -146,10 +133,10 @@ func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessT
|
||||
|
||||
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
||||
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
|
||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
|
||||
// Method 1: Try CodeWhisperer API (most reliable)
|
||||
cwClient := NewCodeWhispererClient(cfg, "")
|
||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
|
||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
@@ -2,77 +2,105 @@ package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Fingerprint 多维度指纹信息
|
||||
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
|
||||
type Fingerprint struct {
|
||||
SDKVersion string // 1.0.20-1.0.27
|
||||
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
|
||||
RuntimeSDKVersion string // 1.0.x (runtime API)
|
||||
StreamingSDKVersion string // 1.0.x (streaming API)
|
||||
OSType string // darwin/windows/linux
|
||||
OSVersion string // 10.0.22621
|
||||
NodeVersion string // 18.x/20.x/22.x
|
||||
KiroVersion string // 0.3.x-0.8.x
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string // SHA256
|
||||
AcceptLanguage string
|
||||
ScreenResolution string // 1920x1080
|
||||
ColorDepth int // 24
|
||||
HardwareConcurrency int // CPU 核心数
|
||||
TimezoneOffset int
|
||||
}
|
||||
|
||||
// FingerprintManager 指纹管理器
|
||||
// FingerprintConfig holds external fingerprint overrides.
|
||||
type FingerprintConfig struct {
|
||||
OIDCSDKVersion string
|
||||
RuntimeSDKVersion string
|
||||
StreamingSDKVersion string
|
||||
OSType string
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string
|
||||
}
|
||||
|
||||
// FingerprintManager manages per-account fingerprint generation and caching.
|
||||
type FingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||
rng *rand.Rand
|
||||
config *FingerprintConfig // External config (Optional)
|
||||
}
|
||||
|
||||
var (
|
||||
sdkVersions = []string{
|
||||
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
|
||||
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
|
||||
// SDK versions
|
||||
oidcSDKVersions = []string{
|
||||
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
|
||||
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
|
||||
}
|
||||
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
|
||||
runtimeSDKVersions = []string{"1.0.0"}
|
||||
// SDKVersions for generateAssistantResponse (streaming API)
|
||||
streamingSDKVersions = []string{"1.0.27"}
|
||||
// Valid OS types
|
||||
osTypes = []string{"darwin", "windows", "linux"}
|
||||
// OS versions
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
|
||||
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
|
||||
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
|
||||
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
|
||||
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
|
||||
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
|
||||
}
|
||||
// Node versions
|
||||
nodeVersions = []string{
|
||||
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
|
||||
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
|
||||
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
|
||||
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
|
||||
"20.18.0", "20.17.0", "20.16.0",
|
||||
}
|
||||
// Kiro IDE versions
|
||||
kiroVersions = []string{
|
||||
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
|
||||
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
|
||||
"0.10.32", "0.10.16", "0.10.10",
|
||||
"0.9.47", "0.9.40", "0.9.2",
|
||||
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
|
||||
}
|
||||
acceptLanguages = []string{
|
||||
"en-US,en;q=0.9",
|
||||
"en-GB,en;q=0.9",
|
||||
"zh-CN,zh;q=0.9,en;q=0.8",
|
||||
"zh-TW,zh;q=0.9,en;q=0.8",
|
||||
"ja-JP,ja;q=0.9,en;q=0.8",
|
||||
"ko-KR,ko;q=0.9,en;q=0.8",
|
||||
"de-DE,de;q=0.9,en;q=0.8",
|
||||
"fr-FR,fr;q=0.9,en;q=0.8",
|
||||
}
|
||||
screenResolutions = []string{
|
||||
"1920x1080", "2560x1440", "3840x2160",
|
||||
"1366x768", "1440x900", "1680x1050",
|
||||
"2560x1600", "3440x1440",
|
||||
}
|
||||
colorDepths = []int{24, 32}
|
||||
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
|
||||
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
|
||||
// Global singleton
|
||||
globalFingerprintManager *FingerprintManager
|
||||
globalFingerprintManagerOnce sync.Once
|
||||
)
|
||||
|
||||
// NewFingerprintManager 创建指纹管理器
|
||||
func GlobalFingerprintManager() *FingerprintManager {
|
||||
globalFingerprintManagerOnce.Do(func() {
|
||||
globalFingerprintManager = NewFingerprintManager()
|
||||
})
|
||||
return globalFingerprintManager
|
||||
}
|
||||
|
||||
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
|
||||
GlobalFingerprintManager().SetConfig(cfg)
|
||||
}
|
||||
|
||||
// SetConfig applies the config and clears the fingerprint cache.
|
||||
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
fm.config = cfg
|
||||
// Clear cached fingerprints so they regenerate with the new config
|
||||
fm.fingerprints = make(map[string]*Fingerprint)
|
||||
}
|
||||
|
||||
func NewFingerprintManager() *FingerprintManager {
|
||||
return &FingerprintManager{
|
||||
fingerprints: make(map[string]*Fingerprint),
|
||||
@@ -80,7 +108,7 @@ func NewFingerprintManager() *FingerprintManager {
|
||||
}
|
||||
}
|
||||
|
||||
// GetFingerprint 获取或生成 Token 关联的指纹
|
||||
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
|
||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
fm.mu.RLock()
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
@@ -101,97 +129,150 @@ func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateFingerprint 生成新的指纹
|
||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||
osType := fm.randomChoice(osTypes)
|
||||
osVersion := fm.randomChoice(osVersions[osType])
|
||||
kiroVersion := fm.randomChoice(kiroVersions)
|
||||
if fm.config != nil {
|
||||
return fm.generateFromConfig(tokenKey)
|
||||
}
|
||||
return fm.generateRandom(tokenKey)
|
||||
}
|
||||
|
||||
fp := &Fingerprint{
|
||||
SDKVersion: fm.randomChoice(sdkVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: fm.randomChoice(nodeVersions),
|
||||
KiroVersion: kiroVersion,
|
||||
AcceptLanguage: fm.randomChoice(acceptLanguages),
|
||||
ScreenResolution: fm.randomChoice(screenResolutions),
|
||||
ColorDepth: fm.randomIntChoice(colorDepths),
|
||||
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
|
||||
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
|
||||
// generateFromConfig uses config values, falling back to random for empty fields.
|
||||
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
|
||||
cfg := fm.config
|
||||
|
||||
// Helper: config value or random selection
|
||||
configOrRandom := func(configVal string, choices []string) string {
|
||||
if configVal != "" {
|
||||
return configVal
|
||||
}
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
|
||||
return fp
|
||||
osType := cfg.OSType
|
||||
if osType == "" {
|
||||
osType = runtime.GOOS
|
||||
if !slices.Contains(osTypes, osType) {
|
||||
osType = osTypes[fm.rng.Intn(len(osTypes))]
|
||||
}
|
||||
}
|
||||
|
||||
osVersion := cfg.OSVersion
|
||||
if osVersion == "" {
|
||||
if versions, ok := osVersions[osType]; ok {
|
||||
osVersion = versions[fm.rng.Intn(len(versions))]
|
||||
}
|
||||
}
|
||||
|
||||
kiroHash := cfg.KiroHash
|
||||
if kiroHash == "" {
|
||||
hash := sha256.Sum256([]byte(tokenKey))
|
||||
kiroHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
return &Fingerprint{
|
||||
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
|
||||
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
|
||||
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
|
||||
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
|
||||
KiroHash: kiroHash,
|
||||
}
|
||||
}
|
||||
|
||||
// generateKiroHash 生成 Kiro Hash
|
||||
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
|
||||
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
|
||||
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
|
||||
// Use accountKey hash as seed for deterministic random selection
|
||||
hash := sha256.Sum256([]byte(accountKey))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
osType := runtime.GOOS
|
||||
if !slices.Contains(osTypes, osType) {
|
||||
osType = osTypes[rng.Intn(len(osTypes))]
|
||||
}
|
||||
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
|
||||
|
||||
return &Fingerprint{
|
||||
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||
KiroHash: hex.EncodeToString(hash[:]),
|
||||
}
|
||||
}
|
||||
|
||||
// randomChoice 随机选择字符串
|
||||
func (fm *FingerprintManager) randomChoice(choices []string) string {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
|
||||
func GenerateAccountKey(seed string) string {
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(hash[:8])
|
||||
}
|
||||
|
||||
// randomIntChoice 随机选择整数
|
||||
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
|
||||
func GetAccountKey(clientID, refreshToken string) string {
|
||||
// 1. Prefer ClientID
|
||||
if clientID != "" {
|
||||
return GenerateAccountKey(clientID)
|
||||
}
|
||||
|
||||
// 2. Fallback to RefreshToken
|
||||
if refreshToken != "" {
|
||||
return GenerateAccountKey(refreshToken)
|
||||
}
|
||||
|
||||
// 3. Random fallback
|
||||
return GenerateAccountKey(uuid.New().String())
|
||||
}
|
||||
|
||||
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
|
||||
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
|
||||
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
|
||||
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
|
||||
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
|
||||
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
|
||||
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
|
||||
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
|
||||
req.Header.Set("Accept-Language", fp.AcceptLanguage)
|
||||
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
|
||||
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
|
||||
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
|
||||
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
|
||||
}
|
||||
|
||||
// RemoveFingerprint 移除 Token 关联的指纹
|
||||
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
delete(fm.fingerprints, tokenKey)
|
||||
}
|
||||
|
||||
// Count 返回当前管理的指纹数量
|
||||
func (fm *FingerprintManager) Count() int {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
return len(fm.fingerprints)
|
||||
}
|
||||
|
||||
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
|
||||
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.SDKVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
|
||||
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func SetOIDCHeaders(req *http.Request) {
|
||||
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
|
||||
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
|
||||
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
|
||||
}
|
||||
|
||||
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
machineID := fp.KiroHash
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
|
||||
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
|
||||
fp.KiroVersion, machineID))
|
||||
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package kiro
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
@@ -26,8 +28,14 @@ func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
if fp == nil {
|
||||
t.Fatal("expected non-nil Fingerprint")
|
||||
}
|
||||
if fp.SDKVersion == "" {
|
||||
t.Error("expected non-empty SDKVersion")
|
||||
if fp.OIDCSDKVersion == "" {
|
||||
t.Error("expected non-empty OIDCSDKVersion")
|
||||
}
|
||||
if fp.RuntimeSDKVersion == "" {
|
||||
t.Error("expected non-empty RuntimeSDKVersion")
|
||||
}
|
||||
if fp.StreamingSDKVersion == "" {
|
||||
t.Error("expected non-empty StreamingSDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
@@ -44,18 +52,6 @@ func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
if fp.KiroHash == "" {
|
||||
t.Error("expected non-empty KiroHash")
|
||||
}
|
||||
if fp.AcceptLanguage == "" {
|
||||
t.Error("expected non-empty AcceptLanguage")
|
||||
}
|
||||
if fp.ScreenResolution == "" {
|
||||
t.Error("expected non-empty ScreenResolution")
|
||||
}
|
||||
if fp.ColorDepth == 0 {
|
||||
t.Error("expected non-zero ColorDepth")
|
||||
}
|
||||
if fp.HardwareConcurrency == 0 {
|
||||
t.Error("expected non-zero HardwareConcurrency")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||
@@ -78,72 +74,18 @@ func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.GetFingerprint("token1")
|
||||
if fm.Count() != 1 {
|
||||
t.Fatalf("expected count 1, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.RemoveFingerprint("token1")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint_NonExistent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.RemoveFingerprint("nonexistent")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.GetFingerprint("token1")
|
||||
fm.GetFingerprint("token2")
|
||||
fm.GetFingerprint("token3")
|
||||
|
||||
if fm.Count() != 3 {
|
||||
t.Errorf("expected count 3, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToRequest(t *testing.T) {
|
||||
func TestBuildUserAgent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
ua := fp.BuildUserAgent()
|
||||
if ua == "" {
|
||||
t.Error("expected non-empty User-Agent")
|
||||
}
|
||||
|
||||
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
|
||||
t.Error("X-Kiro-SDK-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
|
||||
t.Error("X-Kiro-OS-Type header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
|
||||
t.Error("X-Kiro-OS-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
|
||||
t.Error("X-Kiro-Node-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
|
||||
t.Error("X-Kiro-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
|
||||
t.Error("X-Kiro-Hash header mismatch")
|
||||
}
|
||||
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
|
||||
t.Error("Accept-Language header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
|
||||
t.Error("X-Screen-Resolution header mismatch")
|
||||
amzUA := fp.BuildAmzUserAgent()
|
||||
if amzUA == "" {
|
||||
t.Error("expected non-empty X-Amz-User-Agent")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,6 +108,33 @@ func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Set config with empty OSType to trigger runtime.GOOS fallback
|
||||
fm.SetConfig(&FingerprintConfig{
|
||||
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
|
||||
})
|
||||
|
||||
fp := fm.GetFingerprint("test-token")
|
||||
|
||||
// Expected OS type based on runtime.GOOS mapping
|
||||
var expectedOS string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
expectedOS = "darwin"
|
||||
case "windows":
|
||||
expectedOS = "windows"
|
||||
default:
|
||||
expectedOS = "linux"
|
||||
}
|
||||
|
||||
if fp.OSType != expectedOS {
|
||||
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
|
||||
expectedOS, runtime.GOOS, fp.OSType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
const numGoroutines = 100
|
||||
@@ -174,22 +143,18 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
for i := range numGoroutines {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
for j := range numOperations {
|
||||
tokenKey := "token" + string(rune('a'+id%26))
|
||||
switch j % 4 {
|
||||
switch j % 2 {
|
||||
case 0:
|
||||
fm.GetFingerprint(tokenKey)
|
||||
case 1:
|
||||
fm.Count()
|
||||
case 2:
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
case 3:
|
||||
fm.RemoveFingerprint(tokenKey)
|
||||
_ = fp.BuildUserAgent()
|
||||
_ = fp.BuildAmzUserAgent()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
@@ -198,16 +163,20 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestKiroHashUniqueness(t *testing.T) {
|
||||
func TestKiroHashStability(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
hashes := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune(i)))
|
||||
if hashes[fp.KiroHash] {
|
||||
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
|
||||
}
|
||||
hashes[fp.KiroHash] = true
|
||||
// Same token should always return same hash
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
if fp1.KiroHash != fp2.KiroHash {
|
||||
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
|
||||
}
|
||||
|
||||
// Different tokens should have different hashes
|
||||
fp3 := fm.GetFingerprint("token2")
|
||||
if fp1.KiroHash == fp3.KiroHash {
|
||||
t.Errorf("different tokens should have different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,8 +189,590 @@ func TestKiroHashFormat(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, c := range fp.KiroHash {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalFingerprintManager(t *testing.T) {
|
||||
fm1 := GlobalFingerprintManager()
|
||||
fm2 := GlobalFingerprintManager()
|
||||
|
||||
if fm1 == nil {
|
||||
t.Fatal("expected non-nil GlobalFingerprintManager")
|
||||
}
|
||||
if fm1 != fm2 {
|
||||
t.Error("expected GlobalFingerprintManager to return same instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOIDCHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("expected Content-Type header to be set")
|
||||
}
|
||||
|
||||
amzUA := req.Header.Get("x-amz-user-agent")
|
||||
if amzUA == "" {
|
||||
t.Error("expected x-amz-user-agent header to be set")
|
||||
}
|
||||
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, "KiroIDE") {
|
||||
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||
}
|
||||
|
||||
ua := req.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
t.Error("expected User-Agent header to be set")
|
||||
}
|
||||
if !strings.Contains(ua, "api/sso-oidc") {
|
||||
t.Errorf("User-Agent should contain api name: %s", ua)
|
||||
}
|
||||
|
||||
if req.Header.Get("amz-sdk-invocation-id") == "" {
|
||||
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||
}
|
||||
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
|
||||
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
path string
|
||||
queryParams map[string]string
|
||||
want string
|
||||
wantContains []string
|
||||
}{
|
||||
{
|
||||
name: "no query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: nil,
|
||||
want: "https://api.example.com/getUsageLimits",
|
||||
},
|
||||
{
|
||||
name: "empty query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{},
|
||||
want: "https://api.example.com/getUsageLimits",
|
||||
},
|
||||
{
|
||||
name: "single query param",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
},
|
||||
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||
},
|
||||
{
|
||||
name: "multiple query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
|
||||
},
|
||||
wantContains: []string{
|
||||
"https://api.example.com/getUsageLimits?",
|
||||
"origin=AI_EDITOR",
|
||||
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
|
||||
"resourceType=AGENTIC_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "omit empty params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": "",
|
||||
},
|
||||
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
|
||||
if tt.want != "" {
|
||||
if got != tt.want {
|
||||
t.Errorf("buildURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
if tt.wantContains != nil {
|
||||
for _, substr := range tt.wantContains {
|
||||
if !strings.Contains(got, substr) {
|
||||
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserAgentFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
ua := fp.BuildUserAgent()
|
||||
requiredParts := []string{
|
||||
"aws-sdk-js/",
|
||||
"ua/2.1",
|
||||
"os/",
|
||||
"lang/js",
|
||||
"md/nodejs#",
|
||||
"api/codewhispererstreaming#",
|
||||
"m/E",
|
||||
"KiroIDE-",
|
||||
}
|
||||
for _, part := range requiredParts {
|
||||
if !strings.Contains(ua, part) {
|
||||
t.Errorf("User-Agent missing required part %q: %s", part, ua)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAmzUserAgentFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
amzUA := fp.BuildAmzUserAgent()
|
||||
requiredParts := []string{
|
||||
"aws-sdk-js/",
|
||||
"KiroIDE-",
|
||||
}
|
||||
for _, part := range requiredParts {
|
||||
if !strings.Contains(amzUA, part) {
|
||||
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
|
||||
}
|
||||
}
|
||||
|
||||
// Amz-User-Agent should be shorter than User-Agent
|
||||
ua := fp.BuildUserAgent()
|
||||
if len(amzUA) >= len(ua) {
|
||||
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRuntimeHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
accessToken := "test-access-token-1234567890"
|
||||
clientID := "test-client-id-12345"
|
||||
accountKey := GenerateAccountKey(clientID)
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
machineID := fp.KiroHash
|
||||
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
// Check Authorization header
|
||||
if req.Header.Get("Authorization") != "Bearer "+accessToken {
|
||||
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
// Check x-amz-user-agent header
|
||||
amzUA := req.Header.Get("x-amz-user-agent")
|
||||
if amzUA == "" {
|
||||
t.Error("expected x-amz-user-agent header to be set")
|
||||
}
|
||||
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, "KiroIDE-") {
|
||||
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, machineID) {
|
||||
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
|
||||
}
|
||||
|
||||
// Check User-Agent header
|
||||
ua := req.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
t.Error("expected User-Agent header to be set")
|
||||
}
|
||||
if !strings.Contains(ua, "api/codewhispererruntime#") {
|
||||
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
|
||||
}
|
||||
if !strings.Contains(ua, "m/N,E") {
|
||||
t.Errorf("User-Agent should contain m/N,E: %s", ua)
|
||||
}
|
||||
|
||||
// Check amz-sdk-invocation-id (should be a UUID)
|
||||
invocationID := req.Header.Get("amz-sdk-invocation-id")
|
||||
if invocationID == "" {
|
||||
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||
}
|
||||
if len(invocationID) != 36 {
|
||||
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
|
||||
}
|
||||
|
||||
// Check amz-sdk-request
|
||||
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
|
||||
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSDKVersionsAreValid(t *testing.T) {
|
||||
// Verify all OIDC SDK versions match expected format (3.xxx.x)
|
||||
for _, v := range oidcSDKVersions {
|
||||
if !strings.HasPrefix(v, "3.") {
|
||||
t.Errorf("OIDC SDK version should start with 3.: %s", v)
|
||||
}
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range runtimeSDKVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range streamingSDKVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroVersionsAreValid(t *testing.T) {
|
||||
// Verify all Kiro versions match expected format (0.x.xxx)
|
||||
for _, v := range kiroVersions {
|
||||
if !strings.HasPrefix(v, "0.") {
|
||||
t.Errorf("Kiro version should start with 0.: %s", v)
|
||||
}
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Kiro version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeVersionsAreValid(t *testing.T) {
|
||||
// Verify all Node versions match expected format (xx.xx.x)
|
||||
for _, v := range nodeVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Node version should have 3 parts: %s", v)
|
||||
}
|
||||
// Should be Node 20.x or 22.x
|
||||
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
|
||||
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Without config, should generate random fingerprint
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
if fp1 == nil {
|
||||
t.Fatal("expected non-nil fingerprint")
|
||||
}
|
||||
|
||||
// Set config with all fields
|
||||
cfg := &FingerprintConfig{
|
||||
OIDCSDKVersion: "3.999.0",
|
||||
RuntimeSDKVersion: "9.9.9",
|
||||
StreamingSDKVersion: "8.8.8",
|
||||
OSType: "darwin",
|
||||
OSVersion: "99.0.0",
|
||||
NodeVersion: "99.99.99",
|
||||
KiroVersion: "9.9.999",
|
||||
KiroHash: "customhash123",
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
// After setting config, should use config values
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
if fp2.OIDCSDKVersion != "3.999.0" {
|
||||
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
|
||||
}
|
||||
if fp2.RuntimeSDKVersion != "9.9.9" {
|
||||
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
|
||||
}
|
||||
if fp2.StreamingSDKVersion != "8.8.8" {
|
||||
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
|
||||
}
|
||||
if fp2.OSType != "darwin" {
|
||||
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
|
||||
}
|
||||
if fp2.OSVersion != "99.0.0" {
|
||||
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
|
||||
}
|
||||
if fp2.NodeVersion != "99.99.99" {
|
||||
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
|
||||
}
|
||||
if fp2.KiroVersion != "9.9.999" {
|
||||
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
|
||||
}
|
||||
if fp2.KiroHash != "customhash123" {
|
||||
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Set config with only some fields
|
||||
cfg := &FingerprintConfig{
|
||||
KiroVersion: "1.2.345",
|
||||
KiroHash: "myhash",
|
||||
// Other fields empty - should use random
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
// Configured fields should use config values
|
||||
if fp.KiroVersion != "1.2.345" {
|
||||
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
|
||||
}
|
||||
if fp.KiroHash != "myhash" {
|
||||
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
|
||||
}
|
||||
|
||||
// Empty fields should be randomly selected (non-empty)
|
||||
if fp.OIDCSDKVersion == "" {
|
||||
t.Error("expected non-empty OIDCSDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Get fingerprint before config
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
originalHash := fp1.KiroHash
|
||||
|
||||
// Set config
|
||||
cfg := &FingerprintConfig{
|
||||
KiroHash: "newcustomhash",
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
// Same token should now return different fingerprint (cache cleared)
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
if fp2.KiroHash == originalHash {
|
||||
t.Error("expected cache to be cleared after SetConfig")
|
||||
}
|
||||
if fp2.KiroHash != "newcustomhash" {
|
||||
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
seed string
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "Empty seed",
|
||||
seed: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result for empty seed")
|
||||
}
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Simple seed",
|
||||
seed: "test-client-id",
|
||||
check: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||
}
|
||||
// Verify it's valid hex
|
||||
for _, c := range result {
|
||||
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||
t.Errorf("invalid hex character: %c", c)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Same seed produces same result",
|
||||
seed: "deterministic-seed",
|
||||
check: func(t *testing.T, result string) {
|
||||
result2 := GenerateAccountKey("deterministic-seed")
|
||||
if result != result2 {
|
||||
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Different seeds produce different results",
|
||||
seed: "seed-one",
|
||||
check: func(t *testing.T, result string) {
|
||||
result2 := GenerateAccountKey("seed-two")
|
||||
if result == result2 {
|
||||
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateAccountKey(tt.seed)
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
refreshToken string
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "Priority 1: clientID when both provided",
|
||||
clientID: "client-id-123",
|
||||
refreshToken: "refresh-token-456",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("client-id-123")
|
||||
if result != expected {
|
||||
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Priority 2: refreshToken when clientID is empty",
|
||||
clientID: "",
|
||||
refreshToken: "refresh-token-789",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("refresh-token-789")
|
||||
if result != expected {
|
||||
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Priority 3: random when both empty",
|
||||
clientID: "",
|
||||
refreshToken: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||
}
|
||||
// Should be different each time (random UUID)
|
||||
result2 := GetAccountKey("", "")
|
||||
if result == result2 {
|
||||
t.Log("warning: random keys are the same (possible but unlikely)")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "clientID only",
|
||||
clientID: "solo-client-id",
|
||||
refreshToken: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("solo-client-id")
|
||||
if result != expected {
|
||||
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "refreshToken only",
|
||||
clientID: "",
|
||||
refreshToken: "solo-refresh-token",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("solo-refresh-token")
|
||||
if result != expected {
|
||||
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetAccountKey(tt.clientID, tt.refreshToken)
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey_Deterministic(t *testing.T) {
|
||||
// Verify that GetAccountKey produces deterministic results for same inputs
|
||||
clientID := "test-client-id-abc"
|
||||
refreshToken := "test-refresh-token-xyz"
|
||||
|
||||
// Call multiple times with same inputs
|
||||
results := make([]string, 10)
|
||||
for i := range 10 {
|
||||
results[i] = GetAccountKey(clientID, refreshToken)
|
||||
}
|
||||
|
||||
// All results should be identical
|
||||
for i := 1; i < 10; i++ {
|
||||
if results[i] != results[0] {
|
||||
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintDeterministic(t *testing.T) {
|
||||
// Verify that fingerprints are deterministic based on accountKey
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
accountKey := GenerateAccountKey("test-client-id")
|
||||
|
||||
// Get fingerprint multiple times
|
||||
fp1 := fm.GetFingerprint(accountKey)
|
||||
fp2 := fm.GetFingerprint(accountKey)
|
||||
|
||||
// Should be the same pointer (cached)
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint pointer for same key")
|
||||
}
|
||||
|
||||
// Create new manager and verify same values
|
||||
fm2 := NewFingerprintManager()
|
||||
fp3 := fm2.GetFingerprint(accountKey)
|
||||
|
||||
// Values should be identical (deterministic generation)
|
||||
if fp1.KiroHash != fp3.KiroHash {
|
||||
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
|
||||
}
|
||||
if fp1.OSType != fp3.OSType {
|
||||
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
|
||||
}
|
||||
if fp1.OSVersion != fp3.OSVersion {
|
||||
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
|
||||
}
|
||||
if fp1.KiroVersion != fp3.KiroVersion {
|
||||
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
|
||||
}
|
||||
if fp1.NodeVersion != fp3.NodeVersion {
|
||||
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ import (
|
||||
const (
|
||||
// Kiro auth endpoint
|
||||
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
|
||||
|
||||
// Default callback port
|
||||
defaultCallbackPort = 9876
|
||||
|
||||
|
||||
// Auth timeout
|
||||
authTimeout = 10 * time.Minute
|
||||
)
|
||||
@@ -41,8 +41,10 @@ type KiroTokenResponse struct {
|
||||
|
||||
// KiroOAuth handles the OAuth flow for Kiro authentication.
|
||||
type KiroOAuth struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
machineID string
|
||||
kiroVersion string
|
||||
}
|
||||
|
||||
// NewKiroOAuth creates a new Kiro OAuth handler.
|
||||
@@ -51,9 +53,12 @@ func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||
return &KiroOAuth{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
machineID: fp.KiroHash,
|
||||
kiroVersion: fp.KiroVersion,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +195,8 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -256,11 +262,8 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Use KiroIDE-style User-Agent to match official Kiro IDE behavior
|
||||
// This helps avoid 403 errors from server-side User-Agent validation
|
||||
userAgent := buildKiroUserAgent(tokenKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -301,19 +304,6 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
|
||||
// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
|
||||
// Otherwise generates a simple KiroIDE User-Agent.
|
||||
func buildKiroUserAgent(tokenKey string) string {
|
||||
if tokenKey != "" {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
|
||||
}
|
||||
// Default KiroIDE User-Agent matching kiro-openai-gateway format
|
||||
return "KiroIDE-0.7.45-cli-proxy-api"
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||
|
||||
@@ -35,35 +35,35 @@ const (
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
@@ -104,7 +104,7 @@ func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
method := c.Query("method")
|
||||
|
||||
|
||||
if method == "" {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||
return
|
||||
@@ -138,7 +138,7 @@ func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
|
||||
var provider string
|
||||
if method == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
@@ -373,22 +373,28 @@ func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSess
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||
|
||||
// Fetch profileArn for IDC
|
||||
var profileArn string
|
||||
if session.authMethod == "idc" {
|
||||
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||
}
|
||||
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
StartURL: session.startURL,
|
||||
}
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
StartURL: session.startURL,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
@@ -442,7 +448,7 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
fileName := GenerateTokenFileName(tokenData)
|
||||
|
||||
authFilePath := filepath.Join(authDir, fileName)
|
||||
|
||||
|
||||
// Convert to storage format and save
|
||||
storage := &KiroTokenStorage{
|
||||
Type: "kiro",
|
||||
@@ -459,12 +465,12 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
|
||||
|
||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RefreshManager 是后台刷新器的单例管理器
|
||||
// RefreshManager is a singleton manager for background token refreshing.
|
||||
type RefreshManager struct {
|
||||
mu sync.Mutex
|
||||
refresher *BackgroundRefresher
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
started bool
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -25,7 +25,7 @@ var (
|
||||
managerOnce sync.Once
|
||||
)
|
||||
|
||||
// GetRefreshManager 获取全局刷新管理器实例
|
||||
// GetRefreshManager returns the global RefreshManager singleton.
|
||||
func GetRefreshManager() *RefreshManager {
|
||||
managerOnce.Do(func() {
|
||||
globalRefreshManager = &RefreshManager{}
|
||||
@@ -33,9 +33,7 @@ func GetRefreshManager() *RefreshManager {
|
||||
return globalRefreshManager
|
||||
}
|
||||
|
||||
// Initialize 初始化后台刷新器
|
||||
// baseDir: token 文件所在的目录
|
||||
// cfg: 应用配置
|
||||
// Initialize sets up the background refresher.
|
||||
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -58,18 +56,16 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||
baseDir = resolvedBaseDir
|
||||
}
|
||||
|
||||
// 创建 token 存储库
|
||||
repo := NewFileTokenRepository(baseDir)
|
||||
|
||||
// 创建后台刷新器,配置参数
|
||||
opts := []RefresherOption{
|
||||
WithInterval(time.Minute), // 每分钟检查一次
|
||||
WithBatchSize(50), // 每批最多处理 50 个 token
|
||||
WithConcurrency(10), // 最多 10 个并发刷新
|
||||
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
|
||||
WithInterval(time.Minute),
|
||||
WithBatchSize(50),
|
||||
WithConcurrency(10),
|
||||
WithConfig(cfg),
|
||||
}
|
||||
|
||||
// 如果已设置回调,传递给 BackgroundRefresher
|
||||
// Pass callback to BackgroundRefresher if already set
|
||||
if m.onTokenRefreshed != nil {
|
||||
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
|
||||
}
|
||||
@@ -80,7 +76,7 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动后台刷新
|
||||
// Start begins background token refreshing.
|
||||
func (m *RefreshManager) Start() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -102,7 +98,7 @@ func (m *RefreshManager) Start() {
|
||||
log.Info("refresh manager: background refresh started")
|
||||
}
|
||||
|
||||
// Stop 停止后台刷新
|
||||
// Stop halts background token refreshing.
|
||||
func (m *RefreshManager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -123,14 +119,14 @@ func (m *RefreshManager) Stop() {
|
||||
log.Info("refresh manager: background refresh stopped")
|
||||
}
|
||||
|
||||
// IsRunning 检查后台刷新是否正在运行
|
||||
// IsRunning reports whether background refreshing is active.
|
||||
func (m *RefreshManager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.started
|
||||
}
|
||||
|
||||
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
|
||||
// UpdateBaseDir changes the token directory at runtime.
|
||||
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -143,16 +139,15 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
|
||||
// 可以在任何时候调用,支持运行时更新回调
|
||||
// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据
|
||||
// SetOnTokenRefreshed registers a callback invoked after a successful token refresh.
|
||||
// Can be called at any time; supports runtime callback updates.
|
||||
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onTokenRefreshed = callback
|
||||
|
||||
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
|
||||
// Update the refresher's callback in a thread-safe manner if already created
|
||||
if m.refresher != nil {
|
||||
m.refresher.callbackMu.Lock()
|
||||
m.refresher.onTokenRefreshed = callback
|
||||
@@ -162,8 +157,11 @@ func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, token
|
||||
log.Debug("refresh manager: token refresh callback registered")
|
||||
}
|
||||
|
||||
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
|
||||
// InitializeAndStart initializes and starts background refreshing (convenience method).
|
||||
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||
// Initialize global fingerprint config
|
||||
initGlobalFingerprintConfig(cfg)
|
||||
|
||||
manager := GetRefreshManager()
|
||||
if err := manager.Initialize(baseDir, cfg); err != nil {
|
||||
log.Errorf("refresh manager: initialization failed: %v", err)
|
||||
@@ -172,7 +170,31 @@ func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||
manager.Start()
|
||||
}
|
||||
|
||||
// StopGlobalRefreshManager 停止全局刷新管理器
|
||||
// initGlobalFingerprintConfig loads fingerprint settings from application config.
|
||||
func initGlobalFingerprintConfig(cfg *config.Config) {
|
||||
if cfg == nil || cfg.KiroFingerprint == nil {
|
||||
return
|
||||
}
|
||||
fpCfg := cfg.KiroFingerprint
|
||||
SetGlobalFingerprintConfig(&FingerprintConfig{
|
||||
OIDCSDKVersion: fpCfg.OIDCSDKVersion,
|
||||
RuntimeSDKVersion: fpCfg.RuntimeSDKVersion,
|
||||
StreamingSDKVersion: fpCfg.StreamingSDKVersion,
|
||||
OSType: fpCfg.OSType,
|
||||
OSVersion: fpCfg.OSVersion,
|
||||
NodeVersion: fpCfg.NodeVersion,
|
||||
KiroVersion: fpCfg.KiroVersion,
|
||||
KiroHash: fpCfg.KiroHash,
|
||||
})
|
||||
log.Debug("kiro: global fingerprint config loaded")
|
||||
}
|
||||
|
||||
// InitFingerprintConfig initializes the global fingerprint config from application config.
|
||||
func InitFingerprintConfig(cfg *config.Config) {
|
||||
initGlobalFingerprintConfig(cfg)
|
||||
}
|
||||
|
||||
// StopGlobalRefreshManager stops the global refresh manager.
|
||||
func StopGlobalRefreshManager() {
|
||||
if globalRefreshManager != nil {
|
||||
globalRefreshManager.Stop()
|
||||
|
||||
@@ -84,6 +84,8 @@ type SocialAuthClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
protocolHandler *ProtocolHandler
|
||||
machineID string
|
||||
kiroVersion string
|
||||
}
|
||||
|
||||
// NewSocialAuthClient creates a new social auth client.
|
||||
@@ -92,10 +94,13 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||
return &SocialAuthClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
protocolHandler: NewProtocolHandler(),
|
||||
machineID: fp.KiroHash,
|
||||
kiroVersion: fp.KiroVersion,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,7 +234,8 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
|
||||
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
|
||||
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
@@ -269,7 +275,8 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
|
||||
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
@@ -466,7 +473,7 @@ func forceDefaultProtocolHandler() {
|
||||
if runtime.GOOS != "linux" {
|
||||
return // Non-Linux platforms use different handler mechanisms
|
||||
}
|
||||
|
||||
|
||||
// Set our handler as default using xdg-mime
|
||||
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||
if err := cmd.Run(); err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -40,21 +41,13 @@ const (
|
||||
// Authorization code flow callback
|
||||
authCodeCallbackPath = "/oauth/callback"
|
||||
authCodeCallbackPort = 19877
|
||||
|
||||
// User-Agent to match official Kiro IDE
|
||||
kiroUserAgent = "KiroIDE"
|
||||
|
||||
// IDC token refresh headers (matching Kiro IDE behavior)
|
||||
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
|
||||
)
|
||||
|
||||
// Sentinel errors for OIDC token polling
|
||||
var (
|
||||
ErrAuthorizationPending = errors.New("authorization_pending")
|
||||
ErrSlowDown = errors.New("slow_down")
|
||||
)
|
||||
|
||||
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
||||
type SSOOIDCClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
@@ -74,10 +67,10 @@ func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
|
||||
|
||||
// RegisterClientResponse from AWS SSO OIDC.
|
||||
type RegisterClientResponse struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
|
||||
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
|
||||
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
|
||||
}
|
||||
|
||||
// StartDeviceAuthResponse from AWS SSO OIDC.
|
||||
@@ -174,8 +167,7 @@ func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region str
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -220,8 +212,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, cli
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -267,8 +258,7 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -311,8 +301,11 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
|
||||
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific OIDC region.
|
||||
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
@@ -331,18 +324,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set headers matching kiro2api's IDC token refresh
|
||||
// These headers are required for successful IDC token refresh
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Accept-Language", "*")
|
||||
req.Header.Set("sec-fetch-mode", "cors")
|
||||
req.Header.Set("User-Agent", "node")
|
||||
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -469,10 +451,10 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
|
||||
|
||||
// Step 5: Get profile ARN from CodeWhisperer API
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||
|
||||
// Fetch user email
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
@@ -502,12 +484,36 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// IDCLoginOptions holds optional parameters for IDC login.
|
||||
type IDCLoginOptions struct {
|
||||
StartURL string // Pre-configured start URL (skips prompt if set)
|
||||
Region string // OIDC region for login and token refresh (defaults to us-east-1)
|
||||
UseDeviceCode bool // Use Device Code flow instead of Auth Code flow
|
||||
}
|
||||
|
||||
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
||||
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
|
||||
// Options can be provided to pre-configure IDC parameters (startURL, region).
|
||||
// If StartURL is provided in opts, IDC flow is used directly without prompting.
|
||||
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context, opts *IDCLoginOptions) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// If IDC options with StartURL are provided, skip method selection and use IDC directly
|
||||
if opts != nil && opts.StartURL != "" {
|
||||
region := opts.Region
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
fmt.Printf("\n Using IDC with Start URL: %s\n", opts.StartURL)
|
||||
fmt.Printf(" Region: %s\n", region)
|
||||
|
||||
if opts.UseDeviceCode {
|
||||
return c.LoginWithIDCAndOptions(ctx, opts.StartURL, region)
|
||||
}
|
||||
return c.LoginWithIDCAuthCode(ctx, opts.StartURL, region)
|
||||
}
|
||||
|
||||
// Prompt for login method
|
||||
options := []string{
|
||||
"Use with Builder ID (personal AWS account)",
|
||||
@@ -520,15 +526,41 @@ func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroToke
|
||||
return c.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// IDC flow - prompt for start URL and region
|
||||
fmt.Println()
|
||||
startURL := promptInput("? Enter Start URL", "")
|
||||
if startURL == "" {
|
||||
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||
// IDC flow - use pre-configured values or prompt
|
||||
var startURL, region string
|
||||
|
||||
if opts != nil {
|
||||
startURL = opts.StartURL
|
||||
region = opts.Region
|
||||
}
|
||||
|
||||
region := promptInput("? Enter Region", defaultIDCRegion)
|
||||
fmt.Println()
|
||||
|
||||
// Use pre-configured startURL or prompt
|
||||
if startURL == "" {
|
||||
startURL = promptInput("? Enter Start URL", "")
|
||||
if startURL == "" {
|
||||
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||
}
|
||||
} else {
|
||||
fmt.Printf(" Using pre-configured Start URL: %s\n", startURL)
|
||||
}
|
||||
|
||||
// Use pre-configured region or prompt
|
||||
if region == "" {
|
||||
region = promptInput("? Enter Region", defaultIDCRegion)
|
||||
} else {
|
||||
fmt.Printf(" Using pre-configured Region: %s\n", region)
|
||||
}
|
||||
|
||||
if opts != nil && opts.UseDeviceCode {
|
||||
return c.LoginWithIDCAndOptions(ctx, startURL, region)
|
||||
}
|
||||
return c.LoginWithIDCAuthCode(ctx, startURL, region)
|
||||
}
|
||||
|
||||
// LoginWithIDCAndOptions performs IDC login with the specified region.
|
||||
func (c *SSOOIDCClient) LoginWithIDCAndOptions(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||
return c.LoginWithIDC(ctx, startURL, region)
|
||||
}
|
||||
|
||||
@@ -550,8 +582,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -594,8 +625,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -639,8 +669,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -702,13 +731,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set headers matching Kiro IDE behavior for better compatibility
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
|
||||
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
||||
req.Header.Set("User-Agent", "node")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -835,12 +858,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Step 5: Get profile ARN from CodeWhisperer API
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
@@ -850,7 +869,7 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ProfileArn: "", // Builder ID has no profile
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
Provider: "AWS",
|
||||
@@ -859,15 +878,15 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||
// Falls back to JWT parsing if userinfo fails.
|
||||
@@ -931,20 +950,64 @@ func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken str
|
||||
return ""
|
||||
}
|
||||
|
||||
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
|
||||
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
|
||||
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
|
||||
// Try ListProfiles API first
|
||||
profileArn := c.tryListProfiles(ctx, accessToken)
|
||||
// FetchProfileArn fetches the profile ARN from ListAvailableProfiles API.
|
||||
// This is used to get profileArn for imported accounts that may not have it.
|
||||
func (c *SSOOIDCClient) FetchProfileArn(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||
profileArn := c.tryListAvailableProfiles(ctx, accessToken, clientID, refreshToken)
|
||||
if profileArn != "" {
|
||||
return profileArn
|
||||
}
|
||||
|
||||
// Fallback: Try ListAvailableCustomizations
|
||||
return c.tryListCustomizations(ctx, accessToken)
|
||||
return c.tryListProfilesLegacy(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
|
||||
func (c *SSOOIDCClient) tryListAvailableProfiles(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetKiroAPIEndpoint("")+"/ListAvailableProfiles", strings.NewReader("{}"))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
accountKey := GetAccountKey(clientID, refreshToken)
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("ListAvailableProfiles request failed: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("ListAvailableProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Debugf("ListAvailableProfiles response: %s", string(respBody))
|
||||
|
||||
var result struct {
|
||||
Profiles []struct {
|
||||
Arn string `json:"arn"`
|
||||
ProfileName string `json:"profileName"`
|
||||
} `json:"profiles"`
|
||||
NextToken *string `json:"nextToken"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
log.Debugf("ListAvailableProfiles parse error: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(result.Profiles) > 0 {
|
||||
log.Debugf("Found profile: %s (%s)", result.Profiles[0].ProfileName, result.Profiles[0].Arn)
|
||||
return result.Profiles[0].Arn
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *SSOOIDCClient) tryListProfilesLegacy(ctx context.Context, accessToken string) string {
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
}
|
||||
@@ -954,7 +1017,9 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
|
||||
return ""
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
||||
// Use the legacy CodeWhisperer endpoint for JSON-RPC style requests.
|
||||
// The Q endpoint (q.{region}.amazonaws.com) does NOT support x-amz-target headers.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetCodeWhispererLegacyEndpoint(""), strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -973,11 +1038,11 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
log.Debugf("ListProfiles (legacy) failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Debugf("ListProfiles response: %s", string(respBody))
|
||||
log.Debugf("ListProfiles (legacy) response: %s", string(respBody))
|
||||
|
||||
var result struct {
|
||||
Profiles []struct {
|
||||
@@ -1001,63 +1066,6 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
|
||||
|
||||
var result struct {
|
||||
Customizations []struct {
|
||||
Arn string `json:"arn"`
|
||||
} `json:"customizations"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if result.ProfileArn != "" {
|
||||
return result.ProfileArn
|
||||
}
|
||||
|
||||
if len(result.Customizations) > 0 {
|
||||
return result.Customizations[0].Arn
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
|
||||
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
|
||||
payload := map[string]interface{}{
|
||||
@@ -1078,8 +1086,7 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -1105,6 +1112,53 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *SSOOIDCClient) RegisterClientForAuthCodeWithIDC(ctx context.Context, redirectURI, issuerUrl, region string) (*RegisterClientResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||
"redirectUris": []string{redirectURI},
|
||||
"issuerUrl": issuerUrl,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("register client for auth code with IDC failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterClientResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// AuthCodeCallbackResult contains the result from authorization code callback.
|
||||
type AuthCodeCallbackResult struct {
|
||||
Code string
|
||||
@@ -1128,6 +1182,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
|
||||
resultChan := make(chan AuthCodeCallbackResult, 1)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
@@ -1147,6 +1202,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- AuthCodeCallbackResult{Error: errParam}
|
||||
close(doneChan)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1156,6 +1212,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
|
||||
close(doneChan)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1164,6 +1221,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
|
||||
close(doneChan)
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
@@ -1178,7 +1236,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(10 * time.Minute):
|
||||
case <-resultChan:
|
||||
case <-doneChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
@@ -1227,8 +1285,54 @@ func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, c
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *SSOOIDCClient) CreateTokenWithAuthCodeAndRegion(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI, region string) (*CreateTokenResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"code": code,
|
||||
"codeVerifier": codeVerifier,
|
||||
"redirectUri": redirectURI,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -1352,12 +1456,118 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Step 8: Get profile ARN
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: "", // Builder ID has no profile
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
Provider: "AWS",
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SSOOIDCClient) LoginWithIDCAuthCode(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS IDC - Auth Code) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
|
||||
codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateStateForAuthCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\nStarting callback server...")
|
||||
redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("Callback server started, redirect URI: %s", redirectURI)
|
||||
|
||||
fmt.Println("Registering client...")
|
||||
regResp, err := c.RegisterClientForAuthCodeWithIDC(ctx, redirectURI, startURL, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
|
||||
authURL := buildAuthorizationURL(endpoint, regResp.ClientID, redirectURI, scopes, state, codeChallenge)
|
||||
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Println(" Opening browser for authentication...")
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf("\n URL: %s\n\n", authURL)
|
||||
|
||||
if c.cfg != nil {
|
||||
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||
} else {
|
||||
browser.SetIncognitoMode(true)
|
||||
}
|
||||
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Could not open browser automatically: %v", err)
|
||||
fmt.Println(" ⚠ Could not open browser automatically.")
|
||||
fmt.Println(" Please open the URL above in your browser manually.")
|
||||
} else {
|
||||
fmt.Println(" (Browser opened automatically)")
|
||||
}
|
||||
|
||||
fmt.Println("\n Waiting for authorization callback...")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
browser.CloseBrowser()
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(10 * time.Minute):
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
case result := <-resultChan:
|
||||
if result.Error != "" {
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("authorization failed: %s", result.Error)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
tokenResp, err := c.CreateTokenWithAuthCodeAndRegion(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
@@ -1369,12 +1579,25 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildAuthorizationURL(endpoint, clientID, redirectURI, scopes, state, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scopes", scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
return fmt.Sprintf("%s/authorize?%s", endpoint, params.Encode())
|
||||
}
|
||||
|
||||
261
internal/auth/kiro/sso_oidc_test.go
Normal file
261
internal/auth/kiro/sso_oidc_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type recordingRoundTripper struct {
|
||||
lastReq *http.Request
|
||||
}
|
||||
|
||||
func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.lastReq = req
|
||||
body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) {
|
||||
rt := &recordingRoundTripper{}
|
||||
client := &SSOOIDCClient{
|
||||
httpClient: &http.Client{Transport: rt},
|
||||
}
|
||||
|
||||
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456")
|
||||
if profileArn == "" {
|
||||
t.Fatal("expected profileArn, got empty result")
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey("client-id-123", "refresh-token-456")
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
|
||||
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
|
||||
if got != expected {
|
||||
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) {
|
||||
rt := &recordingRoundTripper{}
|
||||
client := &SSOOIDCClient{
|
||||
httpClient: &http.Client{Transport: rt},
|
||||
}
|
||||
|
||||
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789")
|
||||
if profileArn == "" {
|
||||
t.Fatal("expected profileArn, got empty result")
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey("", "refresh-token-789")
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
|
||||
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
|
||||
if got != expected {
|
||||
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterClientForAuthCodeWithIDC(t *testing.T) {
|
||||
var capturedReq struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers http.Header
|
||||
Body map[string]interface{}
|
||||
}
|
||||
|
||||
mockResp := RegisterClientResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientIDIssuedAt: 1700000000,
|
||||
ClientSecretExpiresAt: 1700086400,
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedReq.Method = r.Method
|
||||
capturedReq.Path = r.URL.Path
|
||||
capturedReq.Headers = r.Header.Clone()
|
||||
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
json.Unmarshal(bodyBytes, &capturedReq.Body)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(mockResp)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
// Extract host to build a region that resolves to our test server.
|
||||
// Override getOIDCEndpoint by passing region="" and patching the endpoint.
|
||||
// Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we
|
||||
// instead inject the test server URL directly via a custom HTTP client transport.
|
||||
client := &SSOOIDCClient{
|
||||
httpClient: ts.Client(),
|
||||
}
|
||||
|
||||
// We need to route the request to our test server. Use a transport that rewrites the URL.
|
||||
client.httpClient.Transport = &rewriteTransport{
|
||||
base: ts.Client().Transport,
|
||||
targetURL: ts.URL,
|
||||
}
|
||||
|
||||
resp, err := client.RegisterClientForAuthCodeWithIDC(
|
||||
context.Background(),
|
||||
"http://127.0.0.1:19877/oauth/callback",
|
||||
"https://my-idc-instance.awsapps.com/start",
|
||||
"us-east-1",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify request method and path
|
||||
if capturedReq.Method != http.MethodPost {
|
||||
t.Errorf("method = %q, want POST", capturedReq.Method)
|
||||
}
|
||||
if capturedReq.Path != "/client/register" {
|
||||
t.Errorf("path = %q, want /client/register", capturedReq.Path)
|
||||
}
|
||||
|
||||
// Verify headers
|
||||
if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||
}
|
||||
ua := capturedReq.Headers.Get("User-Agent")
|
||||
if !strings.Contains(ua, "KiroIDE") {
|
||||
t.Errorf("User-Agent %q does not contain KiroIDE", ua)
|
||||
}
|
||||
if !strings.Contains(ua, "sso-oidc") {
|
||||
t.Errorf("User-Agent %q does not contain sso-oidc", ua)
|
||||
}
|
||||
xua := capturedReq.Headers.Get("X-Amz-User-Agent")
|
||||
if !strings.Contains(xua, "KiroIDE") {
|
||||
t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua)
|
||||
}
|
||||
|
||||
// Verify body fields
|
||||
if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" {
|
||||
t.Errorf("clientName = %q, want %q", v, "Kiro IDE")
|
||||
}
|
||||
if v, _ := capturedReq.Body["clientType"].(string); v != "public" {
|
||||
t.Errorf("clientType = %q, want %q", v, "public")
|
||||
}
|
||||
if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" {
|
||||
t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start")
|
||||
}
|
||||
|
||||
// Verify scopes array
|
||||
scopesRaw, ok := capturedReq.Body["scopes"].([]interface{})
|
||||
if !ok || len(scopesRaw) != 5 {
|
||||
t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"])
|
||||
}
|
||||
expectedScopes := []string{
|
||||
"codewhisperer:completions", "codewhisperer:analysis",
|
||||
"codewhisperer:conversations", "codewhisperer:transformations",
|
||||
"codewhisperer:taskassist",
|
||||
}
|
||||
for i, s := range expectedScopes {
|
||||
if scopesRaw[i].(string) != s {
|
||||
t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify grantTypes
|
||||
grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{})
|
||||
if !ok || len(grantTypesRaw) != 2 {
|
||||
t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"])
|
||||
}
|
||||
if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" {
|
||||
t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw)
|
||||
}
|
||||
|
||||
// Verify redirectUris
|
||||
redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{})
|
||||
if !ok || len(redirectRaw) != 1 {
|
||||
t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"])
|
||||
}
|
||||
if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" {
|
||||
t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback")
|
||||
}
|
||||
|
||||
// Verify response parsing
|
||||
if resp.ClientID != "test-client-id" {
|
||||
t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id")
|
||||
}
|
||||
if resp.ClientSecret != "test-client-secret" {
|
||||
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret")
|
||||
}
|
||||
if resp.ClientIDIssuedAt != 1700000000 {
|
||||
t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000)
|
||||
}
|
||||
if resp.ClientSecretExpiresAt != 1700086400 {
|
||||
t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteTransport redirects all requests to the test server URL.
|
||||
type rewriteTransport struct {
|
||||
base http.RoundTripper
|
||||
targetURL string
|
||||
}
|
||||
|
||||
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
target, _ := url.Parse(t.targetURL)
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
if t.base != nil {
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
|
||||
endpoint := "https://oidc.us-east-1.amazonaws.com"
|
||||
redirectURI := "http://127.0.0.1:19877/oauth/callback"
|
||||
|
||||
authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge")
|
||||
|
||||
// Verify colons and commas in scopes are percent-encoded
|
||||
if !strings.Contains(authURL, "codewhisperer%3Acompletions") {
|
||||
t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL)
|
||||
}
|
||||
if !strings.Contains(authURL, "completions%2Ccodewhisperer") {
|
||||
t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL)
|
||||
}
|
||||
|
||||
// Parse back and verify all parameters round-trip correctly
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(authURL, endpoint+"/authorize?") {
|
||||
t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL)
|
||||
}
|
||||
|
||||
q := parsed.Query()
|
||||
checks := map[string]string{
|
||||
"response_type": "code",
|
||||
"client_id": "test-client-id",
|
||||
"redirect_uri": redirectURI,
|
||||
"scopes": scopes,
|
||||
"state": "random-state",
|
||||
"code_challenge": "test-challenge",
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
for key, want := range checks {
|
||||
if got := q.Get(key); got != want {
|
||||
t.Errorf("%s = %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -29,7 +29,7 @@ type KiroTokenStorage struct {
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
// Region is the AWS region
|
||||
// Region is the OIDC region for IDC login and token refresh
|
||||
Region string `json:"region,omitempty"`
|
||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||
StartURL string `json:"start_url,omitempty"`
|
||||
|
||||
@@ -200,36 +200,22 @@ func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
|
||||
}
|
||||
|
||||
// 解析各字段
|
||||
if v, ok := metadata["access_token"].(string); ok {
|
||||
token.AccessToken = v
|
||||
}
|
||||
if v, ok := metadata["refresh_token"].(string); ok {
|
||||
token.RefreshToken = v
|
||||
}
|
||||
if v, ok := metadata["client_id"].(string); ok {
|
||||
token.ClientID = v
|
||||
}
|
||||
if v, ok := metadata["client_secret"].(string); ok {
|
||||
token.ClientSecret = v
|
||||
}
|
||||
if v, ok := metadata["region"].(string); ok {
|
||||
token.Region = v
|
||||
}
|
||||
if v, ok := metadata["start_url"].(string); ok {
|
||||
token.StartURL = v
|
||||
}
|
||||
if v, ok := metadata["provider"].(string); ok {
|
||||
token.Provider = v
|
||||
}
|
||||
token.AccessToken, _ = metadata["access_token"].(string)
|
||||
token.RefreshToken, _ = metadata["refresh_token"].(string)
|
||||
token.ClientID, _ = metadata["client_id"].(string)
|
||||
token.ClientSecret, _ = metadata["client_secret"].(string)
|
||||
token.Region, _ = metadata["region"].(string)
|
||||
token.StartURL, _ = metadata["start_url"].(string)
|
||||
token.Provider, _ = metadata["provider"].(string)
|
||||
|
||||
// 解析时间字段
|
||||
if v, ok := metadata["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
token.ExpiresAt = t
|
||||
}
|
||||
}
|
||||
if v, ok := metadata["last_refresh"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
|
||||
token.LastVerified = t
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -51,14 +50,12 @@ type QuotaStatus struct {
|
||||
// UsageChecker provides methods for checking token quota usage.
|
||||
type UsageChecker struct {
|
||||
httpClient *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// NewUsageChecker creates a new UsageChecker instance.
|
||||
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +63,6 @@ func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: client,
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,26 +76,23 @@ func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData)
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
// Use endpoint from profileArn if available
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn)
|
||||
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", targetGetUsage)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *claude.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == claude.ErrPortInUse.Type {
|
||||
os.Exit(claude.ErrPortInUse.Code)
|
||||
|
||||
@@ -22,6 +22,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
sdkAuth.NewKiloAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
54
internal/cmd/kilo_login.go
Normal file
54
internal/cmd/kilo_login.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
|
||||
// It initiates the device-based authentication process for Kilo AI services and saves
|
||||
// the authentication tokens to the configured auth directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including browser behavior and prompts
|
||||
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Print(prompt)
|
||||
var value string
|
||||
fmt.Scanln(&value)
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
|
||||
if err != nil {
|
||||
fmt.Printf("Kilo authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
|
||||
fmt.Println("Kilo authentication successful!")
|
||||
}
|
||||
@@ -206,3 +206,52 @@ func DoKiroImport(cfg *config.Config, options *LoginOptions) {
|
||||
}
|
||||
fmt.Println("Kiro token import successful!")
|
||||
}
|
||||
|
||||
func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
if startURL == "" {
|
||||
log.Errorf("Kiro IDC login requires --kiro-idc-start-url")
|
||||
fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start")
|
||||
return
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||
metadata := map[string]string{
|
||||
"start-url": startURL,
|
||||
"region": region,
|
||||
"flow": flow,
|
||||
}
|
||||
|
||||
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: metadata,
|
||||
Prompt: options.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("Kiro IDC authentication failed: %v", err)
|
||||
fmt.Println("\nTroubleshooting:")
|
||||
fmt.Println("1. Make sure your IDC Start URL is correct")
|
||||
fmt.Println("2. Complete the authorization in the browser")
|
||||
fmt.Println("3. If auth code flow fails, try: --kiro-idc-flow device")
|
||||
return
|
||||
}
|
||||
|
||||
savedPath, err := manager.SaveAuth(record, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to save auth: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kiro IDC authentication successful!")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -27,11 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
geminiCLIApiClient = "gl-node/22.17.0"
|
||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
geminiCLIVersion = "v1internal"
|
||||
)
|
||||
|
||||
type projectSelectionRequiredError struct{}
|
||||
@@ -100,49 +98,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
var activatedProjects []string
|
||||
|
||||
useGoogleOne := false
|
||||
if trimmedProjectID == "" && promptFn != nil {
|
||||
fmt.Println("\nSelect login mode:")
|
||||
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
|
||||
fmt.Println(" 2. Google One (personal account, auto-discover project)")
|
||||
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
|
||||
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
|
||||
useGoogleOne = true
|
||||
}
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
if useGoogleOne {
|
||||
log.Info("Google One mode: auto-discovering project...")
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
autoProject := strings.TrimSpace(storage.ProjectID)
|
||||
if autoProject == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
log.Infof("Auto-discovered project: %s", autoProject)
|
||||
activatedProjects = []string{autoProject}
|
||||
} else {
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
storage.Auto = false
|
||||
@@ -235,7 +258,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -343,9 +407,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
||||
return fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
@@ -564,7 +626,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -585,7 +647,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
||||
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||
resp, errDo = httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||
@@ -617,7 +679,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
|
||||
return
|
||||
}
|
||||
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
|
||||
|
||||
if record.Metadata == nil {
|
||||
record.Metadata = make(map[string]any)
|
||||
|
||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
)
|
||||
|
||||
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||
// existing codex-login OAuth callback flow intact.
|
||||
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("Codex device authentication successful!")
|
||||
}
|
||||
@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *codex.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
|
||||
@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
|
||||
}
|
||||
}
|
||||
|
||||
// StartServiceBackground starts the proxy service in a background goroutine
|
||||
// and returns a cancel function for shutdown and a done channel.
|
||||
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
|
||||
builder := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath(configPath).
|
||||
WithLocalManagementPassword(localPassword)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
service, err := builder.Build()
|
||||
if err != nil {
|
||||
log.Errorf("failed to build proxy service: %v", err)
|
||||
close(doneCh)
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("proxy service exited with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
|
||||
// when no configuration file is available.
|
||||
func WaitForCloudDeploy() {
|
||||
|
||||
@@ -69,6 +69,9 @@ type Config struct {
|
||||
|
||||
// RequestRetry defines the retry times when the request failed.
|
||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||
// Set to 0 or a negative value to keep trying all available credentials (legacy behavior).
|
||||
MaxRetryCredentials int `yaml:"max-retry-credentials" json:"max-retry-credentials"`
|
||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
||||
|
||||
@@ -87,6 +90,10 @@ type Config struct {
|
||||
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
|
||||
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
|
||||
|
||||
// KiroFingerprint defines a global fingerprint configuration for all Kiro requests.
|
||||
// When set, all Kiro requests will use this fixed fingerprint instead of random generation.
|
||||
KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"`
|
||||
|
||||
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
|
||||
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
|
||||
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
|
||||
@@ -97,6 +104,10 @@ type Config struct {
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values for Claude API requests.
|
||||
// These are used as fallbacks when the client does not send its own headers.
|
||||
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
|
||||
|
||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
||||
|
||||
@@ -130,6 +141,15 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -301,6 +321,10 @@ type CloakConfig struct {
|
||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||
// This can help bypass certain content filters.
|
||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||
|
||||
// CacheUserID controls whether Claude user_id values are cached per API key.
|
||||
// When false, a fresh random user_id is generated for every request.
|
||||
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeKey represents the configuration for a Claude API key,
|
||||
@@ -368,6 +392,9 @@ type CodexKey struct {
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
// Websockets enables the Responses API websocket transport for this credential.
|
||||
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
@@ -457,6 +484,9 @@ type KiroKey struct {
|
||||
// Region is the AWS region (default: us-east-1).
|
||||
Region string `yaml:"region,omitempty" json:"region,omitempty"`
|
||||
|
||||
// StartURL is the IAM Identity Center (IDC) start URL for SSO login.
|
||||
StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"`
|
||||
|
||||
// ProxyURL optionally overrides the global proxy for this configuration.
|
||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||
|
||||
@@ -469,6 +499,20 @@ type KiroKey struct {
|
||||
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
|
||||
}
|
||||
|
||||
// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests.
|
||||
// When configured, all Kiro requests will use this fixed fingerprint instead of random generation.
|
||||
// Empty fields will fall back to random selection from built-in pools.
|
||||
type KiroFingerprintConfig struct {
|
||||
OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"`
|
||||
RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"`
|
||||
StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"`
|
||||
OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"`
|
||||
OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"`
|
||||
NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"`
|
||||
KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"`
|
||||
KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||
type OpenAICompatibility struct {
|
||||
@@ -535,16 +579,6 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// // Log warning but don't fail - config loading should still work
|
||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
// } else if migrated {
|
||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
// }
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -632,6 +666,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
if cfg.MaxRetryCredentials < 0 {
|
||||
cfg.MaxRetryCredentials = 0
|
||||
}
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
@@ -743,22 +781,24 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
return
|
||||
}
|
||||
|
||||
// Inject default Kiro aliases if no user-configured kiro aliases exist
|
||||
// Inject channel defaults when the channel is absent in user config.
|
||||
// Presence is checked case-insensitively and includes explicit nil/empty markers.
|
||||
if cfg.OAuthModelAlias == nil {
|
||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||
}
|
||||
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
|
||||
// Check case-insensitive too
|
||||
found := false
|
||||
hasChannel := func(channel string) bool {
|
||||
for k := range cfg.OAuthModelAlias {
|
||||
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
|
||||
found = true
|
||||
break
|
||||
if strings.EqualFold(strings.TrimSpace(k), channel) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !hasChannel("kiro") {
|
||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||
}
|
||||
if !hasChannel("github-copilot") {
|
||||
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
|
||||
}
|
||||
|
||||
if len(cfg.OAuthModelAlias) == 0 {
|
||||
@@ -767,7 +807,13 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||
for rawChannel, aliases := range cfg.OAuthModelAlias {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(aliases) == 0 {
|
||||
if channel == "" {
|
||||
continue
|
||||
}
|
||||
// Preserve channels that were explicitly set to empty/nil – they act
|
||||
// as "disabled" markers so default injection won't re-add them (#222).
|
||||
if len(aliases) == 0 {
|
||||
out[channel] = nil
|
||||
continue
|
||||
}
|
||||
seenAlias := make(map[string]struct{}, len(aliases))
|
||||
@@ -1620,9 +1666,6 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
|
||||
srcIdx := findMapKeyIndex(srcRoot, key)
|
||||
if srcIdx < 0 {
|
||||
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
||||
//
|
||||
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
|
||||
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
|
||||
// When users delete the last channel from oauth-model-alias via the management API,
|
||||
// we want that deletion to persist across hot reloads and restarts.
|
||||
if key == "oauth-model-alias" {
|
||||
|
||||
37
internal/config/oauth_model_alias_defaults.go
Normal file
37
internal/config/oauth_model_alias_defaults.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package config
|
||||
|
||||
// defaultKiroAliases returns default oauth-model-alias entries for Kiro.
|
||||
// These aliases expose standard Claude IDs for Kiro-prefixed upstream models.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.6
|
||||
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultGitHubCopilotAliases returns default oauth-model-alias entries for
|
||||
// GitHub Copilot Claude models. It exposes hyphen-style IDs used by clients.
|
||||
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||
}
|
||||
}
|
||||
@@ -1,299 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// antigravityModelConversionTable maps old built-in aliases to actual model names
|
||||
// for the antigravity channel during migration.
|
||||
var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
||||
// names so that clients like Claude Code can use standard names directly.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
// for the antigravity channel when neither field exists.
|
||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
|
||||
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
|
||||
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
|
||||
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
|
||||
// to oauth-model-alias at startup. Returns true if migration was performed.
|
||||
//
|
||||
// Migration flow:
|
||||
// 1. Check if oauth-model-alias exists -> skip migration
|
||||
// 2. Check if oauth-model-mappings exists -> convert and migrate
|
||||
// - For antigravity channel, convert old built-in aliases to actual model names
|
||||
//
|
||||
// 3. Neither exists -> add default antigravity config
|
||||
func MigrateOAuthModelAlias(configFile string) (bool, error) {
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Parse YAML into node tree to preserve structure
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
rootMap := root.Content[0]
|
||||
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if oauth-model-alias already exists
|
||||
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if oauth-model-mappings exists
|
||||
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
|
||||
if oldIdx >= 0 {
|
||||
// Migrate from old field
|
||||
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
|
||||
}
|
||||
|
||||
// Neither field exists - add default antigravity config
|
||||
return addDefaultAntigravityConfig(configFile, &root, rootMap)
|
||||
}
|
||||
|
||||
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
|
||||
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
|
||||
if oldIdx+1 >= len(rootMap.Content) {
|
||||
return false, nil
|
||||
}
|
||||
oldValue := rootMap.Content[oldIdx+1]
|
||||
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Parse the old aliases
|
||||
oldAliases := parseOldAliasNode(oldValue)
|
||||
if len(oldAliases) == 0 {
|
||||
// Remove the old field and write
|
||||
removeMapKeyByIndex(rootMap, oldIdx)
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// Convert model names for antigravity channel
|
||||
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
|
||||
for channel, entries := range oldAliases {
|
||||
converted := make([]OAuthModelAlias, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
newEntry := OAuthModelAlias{
|
||||
Name: entry.Name,
|
||||
Alias: entry.Alias,
|
||||
Fork: entry.Fork,
|
||||
}
|
||||
// Convert model names for antigravity channel
|
||||
if strings.EqualFold(channel, "antigravity") {
|
||||
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
|
||||
newEntry.Name = actual
|
||||
}
|
||||
}
|
||||
converted = append(converted, newEntry)
|
||||
}
|
||||
newAliases[channel] = converted
|
||||
}
|
||||
|
||||
// For antigravity channel, supplement missing default aliases
|
||||
if antigravityEntries, exists := newAliases["antigravity"]; exists {
|
||||
// Build a set of already configured model names (upstream names)
|
||||
configuredModels := make(map[string]bool, len(antigravityEntries))
|
||||
for _, entry := range antigravityEntries {
|
||||
configuredModels[entry.Name] = true
|
||||
}
|
||||
|
||||
// Add missing default aliases
|
||||
for _, defaultAlias := range defaultAntigravityAliases() {
|
||||
if !configuredModels[defaultAlias.Name] {
|
||||
antigravityEntries = append(antigravityEntries, defaultAlias)
|
||||
}
|
||||
}
|
||||
newAliases["antigravity"] = antigravityEntries
|
||||
}
|
||||
|
||||
// Build new node
|
||||
newNode := buildOAuthModelAliasNode(newAliases)
|
||||
|
||||
// Replace old key with new key and value
|
||||
rootMap.Content[oldIdx].Value = "oauth-model-alias"
|
||||
rootMap.Content[oldIdx+1] = newNode
|
||||
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// addDefaultAntigravityConfig adds the default antigravity configuration
|
||||
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
|
||||
defaults := map[string][]OAuthModelAlias{
|
||||
"antigravity": defaultAntigravityAliases(),
|
||||
}
|
||||
newNode := buildOAuthModelAliasNode(defaults)
|
||||
|
||||
// Add new key-value pair
|
||||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
|
||||
rootMap.Content = append(rootMap.Content, keyNode, newNode)
|
||||
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// parseOldAliasNode parses the old oauth-model-mappings node structure
|
||||
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
|
||||
if node == nil || node.Kind != yaml.MappingNode {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string][]OAuthModelAlias)
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
channelNode := node.Content[i]
|
||||
entriesNode := node.Content[i+1]
|
||||
if channelNode == nil || entriesNode == nil {
|
||||
continue
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
|
||||
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
|
||||
continue
|
||||
}
|
||||
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
|
||||
for _, entryNode := range entriesNode.Content {
|
||||
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
|
||||
continue
|
||||
}
|
||||
entry := parseAliasEntry(entryNode)
|
||||
if entry.Name != "" && entry.Alias != "" {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
}
|
||||
if len(entries) > 0 {
|
||||
result[channel] = entries
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseAliasEntry parses a single alias entry node
|
||||
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
|
||||
var entry OAuthModelAlias
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valNode := node.Content[i+1]
|
||||
if keyNode == nil || valNode == nil {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
|
||||
case "name":
|
||||
entry.Name = strings.TrimSpace(valNode.Value)
|
||||
case "alias":
|
||||
entry.Alias = strings.TrimSpace(valNode.Value)
|
||||
case "fork":
|
||||
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
|
||||
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
|
||||
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
for channel, entries := range aliases {
|
||||
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
|
||||
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
|
||||
for _, entry := range entries {
|
||||
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
entryNode.Content = append(entryNode.Content,
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
|
||||
)
|
||||
if entry.Fork {
|
||||
entryNode.Content = append(entryNode.Content,
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
|
||||
)
|
||||
}
|
||||
entriesNode.Content = append(entriesNode.Content, entryNode)
|
||||
}
|
||||
node.Content = append(node.Content, channelNode, entriesNode)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
|
||||
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
|
||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||
return
|
||||
}
|
||||
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
|
||||
return
|
||||
}
|
||||
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
|
||||
}
|
||||
|
||||
// writeYAMLNode writes the YAML node tree back to file
|
||||
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
|
||||
f, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := yaml.NewEncoder(f)
|
||||
enc.SetIndent(2)
|
||||
if err := enc.Encode(root); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -1,245 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `oauth-model-alias:
|
||||
gemini-cli:
|
||||
- name: "gemini-2.5-pro"
|
||||
alias: "g2.5p"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration when oauth-model-alias already exists")
|
||||
}
|
||||
|
||||
// Verify file unchanged
|
||||
data, _ := os.ReadFile(configFile)
|
||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||
t.Fatal("file should still contain oauth-model-alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `oauth-model-mappings:
|
||||
gemini-cli:
|
||||
- name: "gemini-2.5-pro"
|
||||
alias: "g2.5p"
|
||||
fork: true
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify new field exists and old field removed
|
||||
data, _ := os.ReadFile(configFile)
|
||||
if strings.Contains(string(data), "oauth-model-mappings:") {
|
||||
t.Fatal("old field should be removed")
|
||||
}
|
||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||
t.Fatal("new field should exist")
|
||||
}
|
||||
|
||||
// Parse and verify structure
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Use old model names that should be converted
|
||||
content := `oauth-model-mappings:
|
||||
antigravity:
|
||||
- name: "gemini-2.5-computer-use-preview-10-2025"
|
||||
alias: "computer-use"
|
||||
- name: "gemini-3-pro-preview"
|
||||
alias: "g3p"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify model names were converted
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
|
||||
}
|
||||
if !strings.Contains(content, "gemini-3-pro-high") {
|
||||
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
|
||||
}
|
||||
|
||||
// Verify missing default aliases were supplemented
|
||||
if !strings.Contains(content, "gemini-3-pro-image") {
|
||||
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
|
||||
}
|
||||
if !strings.Contains(content, "gemini-3-flash") {
|
||||
t.Fatal("expected missing default alias gemini-3-flash to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-sonnet-4-5") {
|
||||
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `debug: true
|
||||
port: 8080
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to add default config")
|
||||
}
|
||||
|
||||
// Verify default antigravity config was added
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "oauth-model-alias:") {
|
||||
t.Fatal("expected oauth-model-alias to be added")
|
||||
}
|
||||
if !strings.Contains(content, "antigravity:") {
|
||||
t.Fatal("expected antigravity channel to be added")
|
||||
}
|
||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `debug: true
|
||||
port: 8080
|
||||
oauth-model-mappings:
|
||||
gemini-cli:
|
||||
- name: "test"
|
||||
alias: "t"
|
||||
api-keys:
|
||||
- "key1"
|
||||
- "key2"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify other config preserved
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "debug: true") {
|
||||
t.Fatal("expected debug field to be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "port: 8080") {
|
||||
t.Fatal("expected port field to be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "api-keys:") {
|
||||
t.Fatal("expected api-keys field to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for nonexistent file: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration for empty file")
|
||||
}
|
||||
}
|
||||
@@ -107,6 +107,44 @@ func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||
if len(copilotAliases) == 0 {
|
||||
t.Fatal("expected default github-copilot aliases to be injected")
|
||||
}
|
||||
|
||||
aliasSet := make(map[string]bool, len(copilotAliases))
|
||||
for _, a := range copilotAliases {
|
||||
aliasSet[a.Alias] = true
|
||||
if !a.Fork {
|
||||
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||
}
|
||||
}
|
||||
expectedAliases := []string{
|
||||
"claude-haiku-4-5",
|
||||
"claude-opus-4-1",
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-6",
|
||||
}
|
||||
for _, expected := range expectedAliases {
|
||||
if !aliasSet[expected] {
|
||||
t.Fatalf("expected default github-copilot alias %q to be present", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||
// When user has configured kiro aliases, defaults should NOT be injected
|
||||
cfg := &Config{
|
||||
@@ -128,6 +166,88 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"github-copilot": {
|
||||
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||
if len(copilotAliases) != 1 {
|
||||
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
|
||||
}
|
||||
if copilotAliases[0].Alias != "my-opus" {
|
||||
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||
// When user explicitly deletes kiro aliases (key exists with nil value),
|
||||
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": nil, // explicitly deleted
|
||||
"codex": {{Name: "gpt-5", Alias: "g5"}},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
|
||||
}
|
||||
// The key itself must still be present to prevent re-injection on next reload
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
|
||||
}
|
||||
// Other channels should be unaffected
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"github-copilot": nil, // explicitly deleted
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||
if len(copilotAliases) != 0 {
|
||||
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
|
||||
}
|
||||
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
|
||||
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
||||
// Same as above but with empty slice instead of nil (PUT with empty body).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {}, // explicitly set to empty
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
|
||||
}
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
|
||||
// When OAuthModelAlias is nil, kiro defaults should still be injected
|
||||
cfg := &Config{}
|
||||
|
||||
@@ -20,6 +20,10 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||
// Default is false (disabled).
|
||||
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
|
||||
@@ -27,4 +27,7 @@ const (
|
||||
|
||||
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
|
||||
Kiro = "kiro"
|
||||
|
||||
// Kilo represents the Kilo AI provider identifier.
|
||||
Kilo = "kilo"
|
||||
)
|
||||
|
||||
@@ -1 +1 @@
|
||||
[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude.","cache_control":{"type":"ephemeral"}}]
|
||||
[{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}]
|
||||
@@ -1,6 +1,7 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
||||
func LogCredentialSeparator() {
|
||||
log.Debug(credentialSeparator)
|
||||
}
|
||||
|
||||
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||
var data map[string]any
|
||||
|
||||
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||
if srcMap, ok := source.(map[string]any); ok {
|
||||
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||
for k, v := range srcMap {
|
||||
data[k] = v
|
||||
}
|
||||
} else {
|
||||
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||
temp, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(temp, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra metadata
|
||||
if metadata != nil {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
for k, v := range metadata {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -4,10 +4,98 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
|
||||
GeminiCLIVersion = "0.31.0"
|
||||
|
||||
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
|
||||
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
|
||||
)
|
||||
|
||||
// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI.
|
||||
func geminiCLIOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return "win32"
|
||||
default:
|
||||
return runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI.
|
||||
func geminiCLIArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format.
|
||||
// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable.
|
||||
func GeminiCLIUserAgent(model string) string {
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
|
||||
}
|
||||
|
||||
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
|
||||
// proxy infrastructure, client identity, or browser fingerprints from an
|
||||
// outgoing request. This ensures requests to upstream services look like they
|
||||
// originate directly from a native client rather than a third-party client
|
||||
// behind a reverse proxy.
|
||||
func ScrubProxyAndFingerprintHeaders(req *http.Request) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// --- Proxy tracing headers ---
|
||||
req.Header.Del("X-Forwarded-For")
|
||||
req.Header.Del("X-Forwarded-Host")
|
||||
req.Header.Del("X-Forwarded-Proto")
|
||||
req.Header.Del("X-Forwarded-Port")
|
||||
req.Header.Del("X-Real-IP")
|
||||
req.Header.Del("Forwarded")
|
||||
req.Header.Del("Via")
|
||||
|
||||
// --- Client identity headers ---
|
||||
req.Header.Del("X-Title")
|
||||
req.Header.Del("X-Stainless-Lang")
|
||||
req.Header.Del("X-Stainless-Package-Version")
|
||||
req.Header.Del("X-Stainless-Os")
|
||||
req.Header.Del("X-Stainless-Arch")
|
||||
req.Header.Del("X-Stainless-Runtime")
|
||||
req.Header.Del("X-Stainless-Runtime-Version")
|
||||
req.Header.Del("Http-Referer")
|
||||
req.Header.Del("Referer")
|
||||
|
||||
// --- Browser / Chromium fingerprint headers ---
|
||||
// These are sent by Electron-based clients (e.g. CherryStudio) using the
|
||||
// Fetch API, but NOT by Node.js https module (which Antigravity uses).
|
||||
req.Header.Del("Sec-Ch-Ua")
|
||||
req.Header.Del("Sec-Ch-Ua-Mobile")
|
||||
req.Header.Del("Sec-Ch-Ua-Platform")
|
||||
req.Header.Del("Sec-Fetch-Mode")
|
||||
req.Header.Del("Sec-Fetch-Site")
|
||||
req.Header.Del("Sec-Fetch-Dest")
|
||||
req.Header.Del("Priority")
|
||||
|
||||
// --- Encoding negotiation ---
|
||||
// Antigravity (Node.js) sends "gzip, deflate, br" by default;
|
||||
// Electron-based clients may add "zstd" which is a fingerprint mismatch.
|
||||
req.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
// EnsureHeader ensures that a header exists in the target header map by checking
|
||||
// multiple sources in order of priority: source headers, existing target headers,
|
||||
// and finally the default value. It only sets the header if it's not already present
|
||||
|
||||
21
internal/registry/kilo_models.go
Normal file
21
internal/registry/kilo_models.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Package registry provides model definitions for various AI service providers.
|
||||
package registry
|
||||
|
||||
// GetKiloModels returns the Kilo model definitions
|
||||
func GetKiloModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
// --- Base Models ---
|
||||
{
|
||||
ID: "kilo/auto",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
DisplayName: "Kilo Auto",
|
||||
Description: "Automatic model selection by Kilo",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - kiro
|
||||
// - kilo
|
||||
// - github-copilot
|
||||
// - kiro
|
||||
// - amazonq
|
||||
@@ -43,10 +45,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
return GetIFlowModels()
|
||||
case "kimi":
|
||||
return GetKimiModels()
|
||||
case "github-copilot":
|
||||
return GetGitHubCopilotModels()
|
||||
case "kiro":
|
||||
return GetKiroModels()
|
||||
case "kilo":
|
||||
return GetKiloModels()
|
||||
case "amazonq":
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
@@ -93,8 +99,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
GetGitHubCopilotModels(),
|
||||
GetKiroModels(),
|
||||
GetKiloModels(),
|
||||
GetAmazonQModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
@@ -121,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||
func GetGitHubCopilotModels() []*ModelInfo {
|
||||
now := int64(1732752000) // 2024-11-27
|
||||
return []*ModelInfo{
|
||||
gpt4oEntries := []struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Description string
|
||||
}{
|
||||
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
|
||||
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
|
||||
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
|
||||
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
|
||||
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
|
||||
}
|
||||
|
||||
models := []*ModelInfo{
|
||||
{
|
||||
ID: "gpt-4.1",
|
||||
Object: "model",
|
||||
@@ -133,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
}
|
||||
|
||||
for _, entry := range gpt4oEntries {
|
||||
models = append(models, &ModelInfo{
|
||||
ID: entry.ID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: entry.DisplayName,
|
||||
Description: entry.Description,
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
})
|
||||
}
|
||||
|
||||
return append(models, []*ModelInfo{
|
||||
{
|
||||
ID: "gpt-5",
|
||||
Object: "model",
|
||||
@@ -144,6 +181,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-mini",
|
||||
@@ -156,6 +194,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-codex",
|
||||
@@ -168,6 +207,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1",
|
||||
@@ -180,6 +220,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex",
|
||||
@@ -192,6 +233,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-mini",
|
||||
@@ -204,6 +246,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-max",
|
||||
@@ -216,6 +259,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
@@ -228,6 +272,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
@@ -240,6 +285,20 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.3 Codex",
|
||||
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
@@ -313,6 +372,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4.6",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Claude Sonnet 4.6",
|
||||
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
Object: "model",
|
||||
@@ -335,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -369,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
}
|
||||
}...)
|
||||
}
|
||||
|
||||
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||
@@ -400,6 +482,18 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1739836800, // 2025-02-18
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4.6",
|
||||
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5",
|
||||
Object: "model",
|
||||
@@ -448,6 +542,87 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- 第三方模型 (通过 Kiro 接入) ---
|
||||
{
|
||||
ID: "kiro-deepseek-3-2",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro DeepSeek 3.2",
|
||||
Description: "DeepSeek 3.2 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-minimax-m2-1",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro MiniMax M2.1",
|
||||
Description: "MiniMax M2.1 via Kiro",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-qwen3-coder-next",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Qwen3 Coder Next",
|
||||
Description: "Qwen3 Coder Next via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4o",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4o",
|
||||
Description: "OpenAI GPT-4o via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4",
|
||||
Description: "OpenAI GPT-4 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4 Turbo",
|
||||
Description: "OpenAI GPT-4 Turbo via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-3-5-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-3.5 Turbo",
|
||||
Description: "OpenAI GPT-3.5 Turbo via Kiro",
|
||||
ContextLength: 16384,
|
||||
MaxCompletionTokens: 4096,
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6-agentic",
|
||||
@@ -461,6 +636,18 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-sonnet-4-6-agentic",
|
||||
Object: "model",
|
||||
Created: 1739836800, // 2025-02-18
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
|
||||
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5-agentic",
|
||||
Object: "model",
|
||||
@@ -509,6 +696,42 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-deepseek-3-2-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro DeepSeek 3.2 (Agentic)",
|
||||
Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-minimax-m2-1-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro MiniMax M2.1 (Agentic)",
|
||||
Description: "MiniMax M2.1 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-qwen3-coder-next-agentic",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Qwen3 Coder Next (Agentic)",
|
||||
Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771372800, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
@@ -38,6 +49,18 @@ func GetClaudeModels() []*ModelInfo {
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 1000000,
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771286400, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
Description: "Best combination of speed and intelligence",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
@@ -173,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -283,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
@@ -425,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -506,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -742,6 +825,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex-spark",
|
||||
Object: "model",
|
||||
Created: 1770912000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex Spark",
|
||||
Description: "Ultra-fast coding model.",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -774,6 +871,19 @@ func GetQwenModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "coder-model",
|
||||
Object: "model",
|
||||
Created: 1771171200,
|
||||
OwnedBy: "qwen",
|
||||
Type: "qwen",
|
||||
Version: "3.5",
|
||||
DisplayName: "Qwen 3.5 Plus",
|
||||
Description: "efficient hybrid model with leading coding performance",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "vision-model",
|
||||
Object: "model",
|
||||
@@ -806,19 +916,12 @@ func GetIFlowModels() []*ModelInfo {
|
||||
Created int64
|
||||
Thinking *ThinkingSupport
|
||||
}{
|
||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
@@ -827,10 +930,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -859,18 +959,17 @@ type AntigravityModelConfig struct {
|
||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ type ModelInfo struct {
|
||||
SupportedParameters []string `json:"supported_parameters,omitempty"`
|
||||
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
|
||||
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
|
||||
// SupportedInputModalities lists supported input modalities (e.g., TEXT, IMAGE, VIDEO, AUDIO)
|
||||
SupportedInputModalities []string `json:"supportedInputModalities,omitempty"`
|
||||
// SupportedOutputModalities lists supported output modalities (e.g., TEXT, IMAGE)
|
||||
SupportedOutputModalities []string `json:"supportedOutputModalities,omitempty"`
|
||||
|
||||
// Thinking holds provider-specific reasoning/thinking budget capabilities.
|
||||
// This is optional and currently used for Gemini thinking budget normalization.
|
||||
@@ -501,8 +505,11 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if len(model.SupportedEndpoints) > 0 {
|
||||
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
|
||||
if len(model.SupportedInputModalities) > 0 {
|
||||
copyModel.SupportedInputModalities = append([]string(nil), model.SupportedInputModalities...)
|
||||
}
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||
}
|
||||
return ©Model
|
||||
}
|
||||
@@ -601,8 +608,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
@@ -1090,6 +1096,12 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
||||
}
|
||||
if len(model.SupportedInputModalities) > 0 {
|
||||
result["supportedInputModalities"] = model.SupportedInputModalities
|
||||
}
|
||||
if len(model.SupportedOutputModalities) > 0 {
|
||||
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
|
||||
@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(first wsrelay.StreamEvent) {
|
||||
defer close(out)
|
||||
var param any
|
||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
}
|
||||
}(firstEvent)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -45,17 +46,87 @@ const (
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||
)
|
||||
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
||||
antigravityPrimaryModelsCache struct {
|
||||
mu sync.RWMutex
|
||||
models []*registry.ModelInfo
|
||||
}
|
||||
)
|
||||
|
||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*registry.ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneAntigravityModelInfo(model))
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *model
|
||||
if len(model.SupportedGenerationMethods) > 0 {
|
||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||
}
|
||||
if len(model.SupportedParameters) > 0 {
|
||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||
}
|
||||
if model.Thinking != nil {
|
||||
thinkingClone := *model.Thinking
|
||||
if len(model.Thinking.Levels) > 0 {
|
||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||
}
|
||||
clone.Thinking = &thinkingClone
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
|
||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
||||
cloned := cloneAntigravityModels(models)
|
||||
if len(cloned) == 0 {
|
||||
return false
|
||||
}
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = cloned
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
antigravityPrimaryModelsCache.mu.RLock()
|
||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
||||
return cloned
|
||||
}
|
||||
|
||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
||||
models := loadAntigravityPrimaryModels()
|
||||
if len(models) > 0 {
|
||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -72,6 +143,62 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
|
||||
return &AntigravityExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// antigravityTransport is a singleton HTTP/1.1 transport shared by all Antigravity requests.
|
||||
// It is initialized once via antigravityTransportOnce to avoid leaking a new connection pool
|
||||
// (and the goroutines managing it) on every request.
|
||||
var (
|
||||
antigravityTransport *http.Transport
|
||||
antigravityTransportOnce sync.Once
|
||||
)
|
||||
|
||||
func cloneTransportWithHTTP11(base *http.Transport) *http.Transport {
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := base.Clone()
|
||||
clone.ForceAttemptHTTP2 = false
|
||||
// Wipe TLSNextProto to prevent implicit HTTP/2 upgrade.
|
||||
clone.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
|
||||
if clone.TLSClientConfig == nil {
|
||||
clone.TLSClientConfig = &tls.Config{}
|
||||
} else {
|
||||
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
|
||||
}
|
||||
// Actively advertise only HTTP/1.1 in the ALPN handshake.
|
||||
clone.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||
return clone
|
||||
}
|
||||
|
||||
// initAntigravityTransport creates the shared HTTP/1.1 transport exactly once.
|
||||
func initAntigravityTransport() {
|
||||
base, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
base = &http.Transport{}
|
||||
}
|
||||
antigravityTransport = cloneTransportWithHTTP11(base)
|
||||
}
|
||||
|
||||
// newAntigravityHTTPClient creates an HTTP client specifically for Antigravity,
|
||||
// enforcing HTTP/1.1 by disabling HTTP/2 to perfectly mimic Node.js https defaults.
|
||||
// The underlying Transport is a singleton to avoid leaking connection pools.
|
||||
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
antigravityTransportOnce.Do(initAntigravityTransport)
|
||||
|
||||
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
// If no transport is set, use the shared HTTP/1.1 transport.
|
||||
if client.Transport == nil {
|
||||
client.Transport = antigravityTransport
|
||||
return client
|
||||
}
|
||||
|
||||
// Preserve proxy settings from proxy-aware transports while forcing HTTP/1.1.
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
client.Transport = cloneTransportWithHTTP11(transport)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
||||
|
||||
@@ -92,6 +219,8 @@ func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyau
|
||||
}
|
||||
|
||||
// HttpRequest injects Antigravity credentials into the request and executes it.
|
||||
// It uses a whitelist approach: all incoming headers are stripped and only
|
||||
// the minimum set required by the Antigravity protocol is explicitly set.
|
||||
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("antigravity executor: request is nil")
|
||||
@@ -100,10 +229,29 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
|
||||
// --- Whitelist: save only the headers we need from the original request ---
|
||||
contentType := httpReq.Header.Get("Content-Type")
|
||||
|
||||
// Wipe ALL incoming headers
|
||||
for k := range httpReq.Header {
|
||||
delete(httpReq.Header, k)
|
||||
}
|
||||
|
||||
// --- Set only the headers Antigravity actually sends ---
|
||||
if contentType != "" {
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
// Content-Length is managed automatically by Go's http.Client from the Body
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Close = true // sends Connection: close
|
||||
|
||||
// Inject Authorization: Bearer <token>
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
@@ -115,7 +263,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
|
||||
|
||||
if isClaude || strings.Contains(baseModel, "gemini-3-pro") {
|
||||
if isClaude || strings.Contains(baseModel, "gemini-3-pro") || strings.Contains(baseModel, "gemini-3.1-flash-image") {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
@@ -150,7 +298,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||
|
||||
@@ -232,7 +380,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -292,7 +440,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||
|
||||
@@ -436,7 +584,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -645,7 +793,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -684,7 +832,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||
|
||||
@@ -775,7 +923,6 @@ attemptLoop:
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -820,7 +967,7 @@ attemptLoop:
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -887,7 +1034,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
@@ -918,10 +1065,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if errReq != nil {
|
||||
return cliproxyexecutor.Response{}, errReq
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
@@ -968,7 +1115,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
@@ -1008,21 +1155,33 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
return nil
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
|
||||
var payload []byte
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||
}
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
payload = []byte(`{}`)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
@@ -1033,13 +1192,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
@@ -1051,19 +1210,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -1075,7 +1242,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
continue
|
||||
}
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
@@ -1097,6 +1264,29 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
|
||||
// Build input modalities from upstream capability flags.
|
||||
inputModalities := []string{"TEXT"}
|
||||
if modelData.Get("supportsImages").Bool() {
|
||||
inputModalities = append(inputModalities, "IMAGE")
|
||||
}
|
||||
if modelData.Get("supportsVideo").Bool() {
|
||||
inputModalities = append(inputModalities, "VIDEO")
|
||||
}
|
||||
modelInfo.SupportedInputModalities = inputModalities
|
||||
modelInfo.SupportedOutputModalities = []string{"TEXT"}
|
||||
|
||||
// Token limits from upstream.
|
||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||
modelInfo.InputTokenLimit = int(maxTok)
|
||||
}
|
||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||
modelInfo.OutputTokenLimit = int(maxOut)
|
||||
}
|
||||
|
||||
// Supported generation methods (Gemini v1beta convention).
|
||||
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
|
||||
|
||||
// Look up Thinking support from static config using upstream model name.
|
||||
if modelCfg != nil {
|
||||
if modelCfg.Thinking != nil {
|
||||
@@ -1108,9 +1298,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
storeAntigravityPrimaryModels(models)
|
||||
return models
|
||||
}
|
||||
return nil
|
||||
return fallbackAntigravityPrimaryModels()
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
@@ -1155,10 +1354,11 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
||||
return auth, errReq
|
||||
}
|
||||
httpReq.Header.Set("Host", "oauth2.googleapis.com")
|
||||
httpReq.Header.Set("User-Agent", defaultAntigravityAgent)
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
// Real Antigravity uses Go's default User-Agent for OAuth token refresh
|
||||
httpReq.Header.Set("User-Agent", "Go-http-client/2.0")
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
return auth, errDo
|
||||
@@ -1229,7 +1429,7 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au
|
||||
return nil
|
||||
}
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
|
||||
if errFetch != nil {
|
||||
return errFetch
|
||||
@@ -1283,7 +1483,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
@@ -1297,18 +1497,18 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
// if useAntigravitySchema {
|
||||
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
// for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
@@ -1320,14 +1520,10 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
httpReq.Close = true
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if stream {
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
@@ -1539,7 +1735,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
template, _ = sjson.Set(template, "requestType", "agent")
|
||||
|
||||
isImageModel := strings.Contains(modelName, "image")
|
||||
|
||||
var reqType string
|
||||
if isImageModel {
|
||||
reqType = "image_gen"
|
||||
} else {
|
||||
reqType = "agent"
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestType", reqType)
|
||||
|
||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||
if projectID != "" {
|
||||
@@ -1547,8 +1752,13 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
if isImageModel {
|
||||
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
}
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||
@@ -1562,6 +1772,10 @@ func generateRequestID() string {
|
||||
return "agent-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func generateImageGenRequestID() string {
|
||||
return fmt.Sprintf("image_gen/%d/%s/12", time.Now().UnixMilli(), uuid.NewString())
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
randSourceMutex.Lock()
|
||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
||||
antigravityPrimaryModelsCache.mu.Lock()
|
||||
antigravityPrimaryModelsCache.models = nil
|
||||
antigravityPrimaryModelsCache.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
seed := []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5"},
|
||||
{ID: "gemini-2.5-pro"},
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
||||
t.Fatal("expected non-empty model list to update primary cache")
|
||||
}
|
||||
|
||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
||||
}
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
||||
}
|
||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
||||
resetAntigravityPrimaryModelsCacheForTest()
|
||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
||||
|
||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
||||
ID: "gpt-5",
|
||||
DisplayName: "GPT-5",
|
||||
SupportedGenerationMethods: []string{"generateContent"},
|
||||
SupportedParameters: []string{"temperature"},
|
||||
Thinking: ®istry.ThinkingSupport{
|
||||
Levels: []string{"high"},
|
||||
},
|
||||
}}); !updated {
|
||||
t.Fatal("expected model cache update")
|
||||
}
|
||||
|
||||
got := loadAntigravityPrimaryModels()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one cached model, got %d", len(got))
|
||||
}
|
||||
got[0].ID = "mutated-id"
|
||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
||||
}
|
||||
if len(got[0].SupportedParameters) > 0 {
|
||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
||||
}
|
||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
||||
got[0].Thinking.Levels[0] = "mutated-level"
|
||||
}
|
||||
|
||||
again := loadAntigravityPrimaryModels()
|
||||
if len(again) != 1 {
|
||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
||||
}
|
||||
if again[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
||||
}
|
||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
||||
}
|
||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
||||
}
|
||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,9 +2,19 @@ package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
@@ -25,6 +35,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
|
||||
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
|
||||
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
@@ -37,6 +59,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [{"name": "Read"}, {"name": "Write"}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
|
||||
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "web_search"}
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -49,6 +162,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
|
||||
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
|
||||
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
@@ -61,3 +186,400 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
|
||||
payload := bytes.TrimSpace(out)
|
||||
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
|
||||
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "proxy_mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
var requestModels []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
userID := gjson.GetBytes(body, "metadata.user_id").String()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
userIDs = append(userIDs, userID)
|
||||
requestModels = append(requestModels, model)
|
||||
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
|
||||
|
||||
cacheEnabled := true
|
||||
executor := NewClaudeExecutor(&config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{
|
||||
APIKey: "key-123",
|
||||
BaseURL: server.URL,
|
||||
Cloak: &config.CloakConfig{
|
||||
CacheUserID: &cacheEnabled,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
|
||||
for _, model := range models {
|
||||
t.Logf("Sending request for model: %s", model)
|
||||
modelPayload, _ := sjson.SetBytes(payload, "model", model)
|
||||
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: model,
|
||||
Payload: modelPayload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute(%s) error: %v", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) != 2 {
|
||||
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||
}
|
||||
if userIDs[0] == "" || userIDs[1] == "" {
|
||||
t.Fatal("expected user_id to be populated")
|
||||
}
|
||||
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
|
||||
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
|
||||
if userIDs[0] != userIDs[1] {
|
||||
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) {
|
||||
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||
}
|
||||
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute call %d error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) != 2 {
|
||||
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||
}
|
||||
if userIDs[0] == "" || userIDs[1] == "" {
|
||||
t.Fatal("expected user_id to be populated")
|
||||
}
|
||||
if userIDs[0] == userIDs[1] {
|
||||
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
||||
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
|
||||
if got != "mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
|
||||
// tool_result.content can be a string - should not be processed
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
|
||||
if got != "plain string result" {
|
||||
t.Fatalf("string content should remain unchanged = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "web_search" {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
|
||||
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||
}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" {
|
||||
t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h")
|
||||
}
|
||||
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
|
||||
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||
],
|
||||
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||
"messages": [
|
||||
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
|
||||
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
|
||||
}
|
||||
if !gjson.GetBytes(out, "tools.1.cache_control").Exists() {
|
||||
t.Fatalf("tools.1.cache_control (last tool) should be preserved")
|
||||
}
|
||||
if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() {
|
||||
t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t3","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t4","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t5","cache_control":{"type":"ephemeral"}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
|
||||
}
|
||||
if !gjson.GetBytes(out, "tools.4.cache_control").Exists() {
|
||||
t.Fatalf("last tool cache_control should be preserved when possible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
|
||||
var seenBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
seenBody = bytes.Clone(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"input_tokens":42}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||
],
|
||||
"system": [
|
||||
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
|
||||
],
|
||||
"messages": [
|
||||
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-haiku-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens error: %v", err)
|
||||
}
|
||||
|
||||
if len(seenBody) == 0 {
|
||||
t.Fatal("expected count_tokens request body to be captured")
|
||||
}
|
||||
if got := countCacheControls(seenBody); got > 4 {
|
||||
t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
|
||||
}
|
||||
if hasTTLOrderingViolation(seenBody) {
|
||||
t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody))
|
||||
}
|
||||
}
|
||||
|
||||
func hasTTLOrderingViolation(payload []byte) bool {
|
||||
seen5m := false
|
||||
violates := false
|
||||
|
||||
checkCC := func(cc gjson.Result) {
|
||||
if !cc.Exists() || violates {
|
||||
return
|
||||
}
|
||||
ttl := cc.Get("ttl").String()
|
||||
if ttl != "1h" {
|
||||
seen5m = true
|
||||
return
|
||||
}
|
||||
if seen5m {
|
||||
violates = true
|
||||
}
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(payload, "tools")
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
checkCC(tool.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, item gjson.Result) bool {
|
||||
checkCC(item.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.IsArray() {
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
checkCC(item.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
return violates
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func testClaudeExecutorInvalidCompressedErrorBody(
|
||||
t *testing.T,
|
||||
invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("not-a-valid-gzip-stream"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
err := invoke(executor, auth, payload)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to decode error response body") {
|
||||
t.Fatalf("expected decode failure message, got: %v", err)
|
||||
}
|
||||
if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
|
||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account_[uuid]_session_[uuid]
|
||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
|
||||
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
||||
// Format: user_[64-hex-chars]_account_[UUID-v4]_session_[UUID-v4]
|
||||
func generateFakeUserID() string {
|
||||
hexBytes := make([]byte, 32)
|
||||
_, _ = rand.Read(hexBytes)
|
||||
hexPart := hex.EncodeToString(hexBytes)
|
||||
uuidPart := uuid.New().String()
|
||||
return "user_" + hexPart + "_account__session_" + uuidPart
|
||||
accountUUID := uuid.New().String()
|
||||
sessionUUID := uuid.New().String()
|
||||
return "user_" + hexPart + "_account_" + accountUUID + "_session_" + sessionUUID
|
||||
}
|
||||
|
||||
// isValidUserID checks if a user ID matches Claude Code format.
|
||||
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||
}
|
||||
@@ -358,11 +358,10 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
|
||||
@@ -675,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||
err := statusErr{code: statusCode, msg: string(body)}
|
||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
||||
err.retryAfter = retryAfter
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
|
||||
return nil
|
||||
}
|
||||
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
|
||||
resetAtTime := time.Unix(resetsAt, 0)
|
||||
if resetAtTime.After(now) {
|
||||
retryAfter := resetAtTime.Sub(now)
|
||||
return &retryAfter
|
||||
}
|
||||
}
|
||||
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
|
||||
retryAfter := time.Duration(resetsInSeconds) * time.Second
|
||||
return &retryAfter
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
65
internal/runtime/executor/codex_executor_retry_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseCodexRetryAfter(t *testing.T) {
|
||||
now := time.Unix(1_700_000_000, 0)
|
||||
|
||||
t.Run("resets_in_seconds", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 123*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers resets_at", func(t *testing.T) {
|
||||
resetAt := now.Add(5 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 5*time.Minute {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback when resets_at is past", func(t *testing.T) {
|
||||
resetAt := now.Add(-1 * time.Minute).Unix()
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
|
||||
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
|
||||
if retryAfter == nil {
|
||||
t.Fatalf("expected retryAfter, got nil")
|
||||
}
|
||||
if *retryAfter != 77*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-429 status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-429, got %v", *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non usage_limit_reached error type", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
|
||||
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
|
||||
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
1408
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
@@ -81,7 +80,7 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
|
||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(req)
|
||||
applyGeminiCLIHeaders(req, "unknown")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -189,7 +188,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
@@ -225,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -256,7 +255,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -334,7 +333,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP)
|
||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||
reqHTTP.Header.Set("Accept", "text/event-stream")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
@@ -382,7 +381,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -441,7 +439,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
@@ -516,7 +514,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
reqHTTP.Header.Set("Content-Type", "application/json")
|
||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(reqHTTP)
|
||||
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
||||
reqHTTP.Header.Set("Accept", "application/json")
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
@@ -546,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
@@ -739,21 +737,11 @@ func stringValue(m map[string]any, key string) string {
|
||||
}
|
||||
|
||||
// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream.
|
||||
func applyGeminiCLIHeaders(r *http.Request) {
|
||||
var ginHeaders http.Header
|
||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "google-api-nodejs-client/9.15.1")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Client-Metadata", geminiCLIClientMetadata())
|
||||
}
|
||||
|
||||
// geminiCLIClientMetadata returns a compact metadata string required by upstream.
|
||||
func geminiCLIClientMetadata() string {
|
||||
// Keep parity with CLI client defaults
|
||||
return "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
// User-Agent is always forced to the GeminiCLI format regardless of the client's value,
|
||||
// so that upstream identifies the request as a native GeminiCLI client.
|
||||
func applyGeminiCLIHeaders(r *http.Request, model string) {
|
||||
r.Header.Set("User-Agent", misc.GeminiCLIUserAgent(model))
|
||||
r.Header.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader)
|
||||
}
|
||||
|
||||
// cliPreviewFallbackOrder returns preview model candidates for a base model.
|
||||
@@ -899,8 +887,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
return new(time.Duration(seconds) * time.Second), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini API.
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the Gemini API.
|
||||
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
|
||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
@@ -39,7 +39,8 @@ const (
|
||||
copilotEditorVersion = "vscode/1.107.0"
|
||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-edits"
|
||||
copilotOpenAIIntent = "conversation-panel"
|
||||
copilotGitHubAPIVer = "2025-04-01"
|
||||
)
|
||||
|
||||
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||
@@ -51,8 +52,9 @@ type GitHubCopilotExecutor struct {
|
||||
|
||||
// cachedAPIToken stores a cached Copilot API token with its expiry.
|
||||
type cachedAPIToken struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
token string
|
||||
apiEndpoint string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// NewGitHubCopilotExecutor constructs a new executor instance.
|
||||
@@ -75,7 +77,7 @@ func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxy
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||
apiToken, _, errToken := e.ensureAPIToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return errToken
|
||||
}
|
||||
@@ -101,7 +103,7 @@ func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxya
|
||||
|
||||
// Execute handles non-streaming requests to GitHub Copilot.
|
||||
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
}
|
||||
@@ -110,7 +112,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from)
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
|
||||
to := sdktranslator.FromString("openai")
|
||||
if useResponses {
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
@@ -123,6 +125,25 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
|
||||
// Detect vision content before input normalization removes messages
|
||||
hasVision := detectVisionContent(body)
|
||||
|
||||
thinkingProvider := "openai"
|
||||
if useResponses {
|
||||
thinkingProvider = "codex"
|
||||
}
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
@@ -131,7 +152,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
if useResponses {
|
||||
path = githubCopilotResponsesPath
|
||||
}
|
||||
url := githubCopilotBaseURL + path
|
||||
url := baseURL + path
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
@@ -139,7 +160,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
e.applyHeaders(httpReq, apiToken, body)
|
||||
|
||||
// Add Copilot-Vision-Request header if the request contains vision content
|
||||
if detectVisionContent(body) {
|
||||
if hasVision {
|
||||
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
||||
}
|
||||
|
||||
@@ -199,15 +220,20 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
converted := ""
|
||||
if useResponses && from.String() == "claude" {
|
||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||
} else {
|
||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream handles streaming requests to GitHub Copilot.
|
||||
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return nil, errToken
|
||||
}
|
||||
@@ -216,7 +242,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from)
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
|
||||
to := sdktranslator.FromString("openai")
|
||||
if useResponses {
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
@@ -229,6 +255,25 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
|
||||
// Detect vision content before input normalization removes messages
|
||||
hasVision := detectVisionContent(body)
|
||||
|
||||
thinkingProvider := "openai"
|
||||
if useResponses {
|
||||
thinkingProvider = "codex"
|
||||
}
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
@@ -241,7 +286,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
if useResponses {
|
||||
path = githubCopilotResponsesPath
|
||||
}
|
||||
url := githubCopilotBaseURL + path
|
||||
url := baseURL + path
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -249,7 +294,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
e.applyHeaders(httpReq, apiToken, body)
|
||||
|
||||
// Add Copilot-Vision-Request header if the request contains vision content
|
||||
if detectVisionContent(body) {
|
||||
if hasVision {
|
||||
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
||||
}
|
||||
|
||||
@@ -296,7 +341,6 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
|
||||
go func() {
|
||||
defer close(out)
|
||||
@@ -329,7 +373,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
var chunks []string
|
||||
if useResponses && from.String() == "claude" {
|
||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||
} else {
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
}
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -344,7 +393,10 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CountTokens is not supported for GitHub Copilot.
|
||||
@@ -376,22 +428,22 @@ func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
// ensureAPIToken gets or refreshes the Copilot API token.
|
||||
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
|
||||
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, string, error) {
|
||||
if auth == nil {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||
}
|
||||
|
||||
// Get the GitHub access token
|
||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||
if accessToken == "" {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
|
||||
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
|
||||
}
|
||||
|
||||
// Check for cached API token using thread-safe access
|
||||
e.mu.RLock()
|
||||
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
|
||||
e.mu.RUnlock()
|
||||
return cached.token, nil
|
||||
return cached.token, cached.apiEndpoint, nil
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
|
||||
@@ -399,7 +451,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
|
||||
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
||||
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
|
||||
return "", "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
|
||||
}
|
||||
|
||||
// Use endpoint from token response, fall back to default
|
||||
apiEndpoint := githubCopilotBaseURL
|
||||
if apiToken.Endpoints.API != "" {
|
||||
apiEndpoint = strings.TrimRight(apiToken.Endpoints.API, "/")
|
||||
}
|
||||
|
||||
// Cache the token with thread-safe access
|
||||
@@ -409,12 +467,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
|
||||
}
|
||||
e.mu.Lock()
|
||||
e.cache[accessToken] = &cachedAPIToken{
|
||||
token: apiToken.Token,
|
||||
expiresAt: expiresAt,
|
||||
token: apiToken.Token,
|
||||
apiEndpoint: apiEndpoint,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
return apiToken.Token, nil
|
||||
return apiToken.Token, apiEndpoint, nil
|
||||
}
|
||||
|
||||
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
||||
@@ -427,21 +486,50 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
|
||||
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||
r.Header.Set("X-Github-Api-Version", copilotGitHubAPIVer)
|
||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||
|
||||
initiator := "user"
|
||||
if len(body) > 0 {
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
if len(arr) > 0 {
|
||||
lastRole := arr[len(arr)-1].Get("role").String()
|
||||
if lastRole != "" && lastRole != "user" {
|
||||
initiator = "agent"
|
||||
}
|
||||
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" {
|
||||
initiator = "agent"
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
}
|
||||
|
||||
func detectLastConversationRole(body []byte) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
if role := arr[i].Get("role").String(); role != "" {
|
||||
return role
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
|
||||
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
|
||||
arr := inputs.Array()
|
||||
for i := len(arr) - 1; i >= 0; i-- {
|
||||
item := arr[i]
|
||||
|
||||
// Most Responses input items carry a top-level role.
|
||||
if role := item.Get("role").String(); role != "" {
|
||||
return role
|
||||
}
|
||||
|
||||
switch item.Get("type").String() {
|
||||
case "function_call", "function_call_arguments":
|
||||
return "assistant"
|
||||
case "function_call_output", "function_call_response", "tool_result":
|
||||
return "tool"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// detectVisionContent checks if the request body contains vision/image content.
|
||||
@@ -483,8 +571,12 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
|
||||
return body
|
||||
}
|
||||
|
||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
|
||||
return sourceFormat.String() == "openai-response"
|
||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
|
||||
if sourceFormat.String() == "openai-response" {
|
||||
return true
|
||||
}
|
||||
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
||||
return strings.Contains(baseModel, "codex")
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
@@ -504,6 +596,17 @@ func flattenAssistantContent(body []byte) []byte {
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
// Skip flattening if the content contains non-text blocks (tool_use, thinking, etc.)
|
||||
hasNonText := false
|
||||
for _, part := range content.Array() {
|
||||
if t := part.Get("type").String(); t != "" && t != "text" {
|
||||
hasNonText = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasNonText {
|
||||
continue
|
||||
}
|
||||
var textParts []string
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
@@ -519,6 +622,644 @@ func flattenAssistantContent(body []byte) []byte {
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotChatTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
filtered := "[]"
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
|
||||
}
|
||||
|
||||
toolChoice := gjson.GetBytes(body, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return body
|
||||
}
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "auto", "none", "required":
|
||||
return body
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
input := gjson.GetBytes(body, "input")
|
||||
if input.Exists() {
|
||||
// If input is already a string or array, keep it as-is.
|
||||
if input.Type == gjson.String || input.IsArray() {
|
||||
return body
|
||||
}
|
||||
// Non-string/non-array input: stringify as fallback.
|
||||
body, _ = sjson.SetBytes(body, "input", input.Raw)
|
||||
return body
|
||||
}
|
||||
|
||||
// Convert Claude messages format to OpenAI Responses API input array.
|
||||
// This preserves the conversation structure (roles, tool calls, tool results)
|
||||
// which is critical for multi-turn tool-use conversations.
|
||||
inputArr := "[]"
|
||||
|
||||
// System messages → developer role
|
||||
if system := gjson.GetBytes(body, "system"); system.Exists() {
|
||||
var systemParts []string
|
||||
if system.IsArray() {
|
||||
for _, part := range system.Array() {
|
||||
if txt := part.Get("text").String(); txt != "" {
|
||||
systemParts = append(systemParts, txt)
|
||||
}
|
||||
}
|
||||
} else if system.Type == gjson.String {
|
||||
systemParts = append(systemParts, system.String())
|
||||
}
|
||||
if len(systemParts) > 0 {
|
||||
msg := `{"type":"message","role":"developer","content":[]}`
|
||||
for _, txt := range systemParts {
|
||||
part := `{"type":"input_text","text":""}`
|
||||
part, _ = sjson.Set(part, "text", txt)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Messages → structured input items
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
for _, msg := range messages.Array() {
|
||||
role := msg.Get("role").String()
|
||||
content := msg.Get("content")
|
||||
|
||||
if !content.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Simple string content
|
||||
if content.Type == gjson.String {
|
||||
textType := "input_text"
|
||||
if role == "assistant" {
|
||||
textType = "output_text"
|
||||
}
|
||||
item := `{"type":"message","role":"","content":[]}`
|
||||
item, _ = sjson.Set(item, "role", role)
|
||||
part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
|
||||
part, _ = sjson.Set(part, "text", content.String())
|
||||
item, _ = sjson.SetRaw(item, "content.-1", part)
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
|
||||
continue
|
||||
}
|
||||
|
||||
if !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Array content: split into message parts vs tool items
|
||||
var msgParts []string
|
||||
for _, c := range content.Array() {
|
||||
cType := c.Get("type").String()
|
||||
switch cType {
|
||||
case "text":
|
||||
textType := "input_text"
|
||||
if role == "assistant" {
|
||||
textType = "output_text"
|
||||
}
|
||||
part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
|
||||
part, _ = sjson.Set(part, "text", c.Get("text").String())
|
||||
msgParts = append(msgParts, part)
|
||||
case "image":
|
||||
source := c.Get("source")
|
||||
if source.Exists() {
|
||||
data := source.Get("data").String()
|
||||
if data == "" {
|
||||
data = source.Get("base64").String()
|
||||
}
|
||||
mediaType := source.Get("media_type").String()
|
||||
if mediaType == "" {
|
||||
mediaType = source.Get("mime_type").String()
|
||||
}
|
||||
if mediaType == "" {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
if data != "" {
|
||||
part := `{"type":"input_image","image_url":""}`
|
||||
part, _ = sjson.Set(part, "image_url", fmt.Sprintf("data:%s;base64,%s", mediaType, data))
|
||||
msgParts = append(msgParts, part)
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
// Flush any accumulated message parts first
|
||||
if len(msgParts) > 0 {
|
||||
item := `{"type":"message","role":"","content":[]}`
|
||||
item, _ = sjson.Set(item, "role", role)
|
||||
for _, p := range msgParts {
|
||||
item, _ = sjson.SetRaw(item, "content.-1", p)
|
||||
}
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
|
||||
msgParts = nil
|
||||
}
|
||||
fc := `{"type":"function_call","call_id":"","name":"","arguments":""}`
|
||||
fc, _ = sjson.Set(fc, "call_id", c.Get("id").String())
|
||||
fc, _ = sjson.Set(fc, "name", c.Get("name").String())
|
||||
if inputRaw := c.Get("input"); inputRaw.Exists() {
|
||||
fc, _ = sjson.Set(fc, "arguments", inputRaw.Raw)
|
||||
}
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", fc)
|
||||
case "tool_result":
|
||||
// Flush any accumulated message parts first
|
||||
if len(msgParts) > 0 {
|
||||
item := `{"type":"message","role":"","content":[]}`
|
||||
item, _ = sjson.Set(item, "role", role)
|
||||
for _, p := range msgParts {
|
||||
item, _ = sjson.SetRaw(item, "content.-1", p)
|
||||
}
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
|
||||
msgParts = nil
|
||||
}
|
||||
fco := `{"type":"function_call_output","call_id":"","output":""}`
|
||||
fco, _ = sjson.Set(fco, "call_id", c.Get("tool_use_id").String())
|
||||
// Extract output text
|
||||
resultContent := c.Get("content")
|
||||
if resultContent.Type == gjson.String {
|
||||
fco, _ = sjson.Set(fco, "output", resultContent.String())
|
||||
} else if resultContent.IsArray() {
|
||||
var resultParts []string
|
||||
for _, rc := range resultContent.Array() {
|
||||
if txt := rc.Get("text").String(); txt != "" {
|
||||
resultParts = append(resultParts, txt)
|
||||
}
|
||||
}
|
||||
fco, _ = sjson.Set(fco, "output", strings.Join(resultParts, "\n"))
|
||||
} else if resultContent.Exists() {
|
||||
fco, _ = sjson.Set(fco, "output", resultContent.String())
|
||||
}
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", fco)
|
||||
case "thinking":
|
||||
// Skip thinking blocks - not part of the API input
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining message parts
|
||||
if len(msgParts) > 0 {
|
||||
item := `{"type":"message","role":"","content":[]}`
|
||||
item, _ = sjson.Set(item, "role", role)
|
||||
for _, p := range msgParts {
|
||||
item, _ = sjson.SetRaw(item, "content.-1", p)
|
||||
}
|
||||
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, _ = sjson.SetRawBytes(body, "input", []byte(inputArr))
|
||||
// Remove messages/system since we've converted them to input
|
||||
body, _ = sjson.DeleteBytes(body, "messages")
|
||||
body, _ = sjson.DeleteBytes(body, "system")
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
filtered := "[]"
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
// Accept OpenAI format (type="function") and Claude format
|
||||
// (no type field, but has top-level name + input_schema).
|
||||
if toolType != "" && toolType != "function" {
|
||||
continue
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
if name == "" {
|
||||
name = tool.Get("function.name").String()
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
normalized := `{"type":"function","name":""}`
|
||||
normalized, _ = sjson.Set(normalized, "name", name)
|
||||
if desc := tool.Get("description").String(); desc != "" {
|
||||
normalized, _ = sjson.Set(normalized, "description", desc)
|
||||
} else if desc = tool.Get("function.description").String(); desc != "" {
|
||||
normalized, _ = sjson.Set(normalized, "description", desc)
|
||||
}
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
} else if params = tool.Get("function.parameters"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
} else if params = tool.Get("input_schema"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
}
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
|
||||
}
|
||||
|
||||
toolChoice := gjson.GetBytes(body, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return body
|
||||
}
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "auto", "none", "required":
|
||||
return body
|
||||
default:
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
}
|
||||
if toolChoice.Type == gjson.JSON {
|
||||
choiceType := toolChoice.Get("type").String()
|
||||
if choiceType == "function" {
|
||||
name := toolChoice.Get("name").String()
|
||||
if name == "" {
|
||||
name = toolChoice.Get("function.name").String()
|
||||
}
|
||||
if name != "" {
|
||||
normalized := `{"type":"function","name":""}`
|
||||
normalized, _ = sjson.Set(normalized, "name", name)
|
||||
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
|
||||
return body
|
||||
}
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
|
||||
func collectTextFromNode(node gjson.Result) string {
|
||||
if !node.Exists() {
|
||||
return ""
|
||||
}
|
||||
if node.Type == gjson.String {
|
||||
return node.String()
|
||||
}
|
||||
if node.IsArray() {
|
||||
var parts []string
|
||||
for _, item := range node.Array() {
|
||||
if item.Type == gjson.String {
|
||||
if text := item.String(); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if text := item.Get("text").String(); text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
if nested := collectTextFromNode(item.Get("content")); nested != "" {
|
||||
parts = append(parts, nested)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
if node.Type == gjson.JSON {
|
||||
if text := node.Get("text").String(); text != "" {
|
||||
return text
|
||||
}
|
||||
if nested := collectTextFromNode(node.Get("content")); nested != "" {
|
||||
return nested
|
||||
}
|
||||
return node.Raw
|
||||
}
|
||||
return node.String()
|
||||
}
|
||||
|
||||
type githubCopilotResponsesStreamToolState struct {
|
||||
Index int
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
type githubCopilotResponsesStreamState struct {
|
||||
MessageStarted bool
|
||||
MessageStopSent bool
|
||||
TextBlockStarted bool
|
||||
TextBlockIndex int
|
||||
NextContentIndex int
|
||||
HasToolUse bool
|
||||
ReasoningActive bool
|
||||
ReasoningIndex int
|
||||
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
|
||||
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
root := gjson.ParseBytes(data)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("model").String())
|
||||
|
||||
hasToolUse := false
|
||||
if output := root.Get("output"); output.Exists() && output.IsArray() {
|
||||
for _, item := range output.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "reasoning":
|
||||
var thinkingText string
|
||||
if summary := item.Get("summary"); summary.Exists() && summary.IsArray() {
|
||||
var parts []string
|
||||
for _, part := range summary.Array() {
|
||||
if txt := part.Get("text").String(); txt != "" {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
thinkingText = strings.Join(parts, "")
|
||||
}
|
||||
if thinkingText == "" {
|
||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||
var parts []string
|
||||
for _, part := range content.Array() {
|
||||
if txt := part.Get("text").String(); txt != "" {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
thinkingText = strings.Join(parts, "")
|
||||
}
|
||||
}
|
||||
if thinkingText != "" {
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingText)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
case "message":
|
||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() != "output_text" {
|
||||
continue
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
hasToolUse = true
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolID := item.Get("call_id").String()
|
||||
if toolID == "" {
|
||||
toolID = item.Get("id").String()
|
||||
}
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
|
||||
if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
|
||||
argObj := gjson.Parse(args)
|
||||
if argObj.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inputTokens := root.Get("usage.input_tokens").Int()
|
||||
outputTokens := root.Get("usage.output_tokens").Int()
|
||||
cachedTokens := root.Get("usage.input_tokens_details.cached_tokens").Int()
|
||||
if cachedTokens > 0 && inputTokens >= cachedTokens {
|
||||
inputTokens -= cachedTokens
|
||||
}
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
if hasToolUse {
|
||||
out, _ = sjson.Set(out, "stop_reason", "tool_use")
|
||||
} else if sr := root.Get("stop_reason").String(); sr == "max_tokens" || sr == "stop" {
|
||||
out, _ = sjson.Set(out, "stop_reason", sr)
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &githubCopilotResponsesStreamState{
|
||||
TextBlockIndex: -1,
|
||||
OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
|
||||
ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
|
||||
}
|
||||
}
|
||||
state := (*param).(*githubCopilotResponsesStreamState)
|
||||
|
||||
if !bytes.HasPrefix(line, dataTag) {
|
||||
return nil
|
||||
}
|
||||
payload := bytes.TrimSpace(line[5:])
|
||||
if bytes.Equal(payload, []byte("[DONE]")) {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return nil
|
||||
}
|
||||
|
||||
event := gjson.GetBytes(payload, "type").String()
|
||||
results := make([]string, 0, 4)
|
||||
ensureMessageStart := func() {
|
||||
if state.MessageStarted {
|
||||
return
|
||||
}
|
||||
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
||||
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
||||
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
||||
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
|
||||
state.MessageStarted = true
|
||||
}
|
||||
startTextBlockIfNeeded := func() {
|
||||
if state.TextBlockStarted {
|
||||
return
|
||||
}
|
||||
if state.TextBlockIndex < 0 {
|
||||
state.TextBlockIndex = state.NextContentIndex
|
||||
state.NextContentIndex++
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
state.TextBlockStarted = true
|
||||
}
|
||||
stopTextBlockIfNeeded := func() {
|
||||
if !state.TextBlockStarted {
|
||||
return
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
state.TextBlockStarted = false
|
||||
state.TextBlockIndex = -1
|
||||
}
|
||||
resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
|
||||
if itemID != "" {
|
||||
if tool, ok := state.ItemIDToTool[itemID]; ok {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
|
||||
if itemID != "" {
|
||||
state.ItemIDToTool[itemID] = tool
|
||||
}
|
||||
return tool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch event {
|
||||
case "response.created":
|
||||
ensureMessageStart()
|
||||
case "response.output_text.delta":
|
||||
ensureMessageStart()
|
||||
startTextBlockIfNeeded()
|
||||
delta := gjson.GetBytes(payload, "delta").String()
|
||||
if delta != "" {
|
||||
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
||||
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
|
||||
}
|
||||
case "response.reasoning_summary_part.added":
|
||||
ensureMessageStart()
|
||||
state.ReasoningActive = true
|
||||
state.ReasoningIndex = state.NextContentIndex
|
||||
state.NextContentIndex++
|
||||
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
|
||||
case "response.reasoning_summary_text.delta":
|
||||
if state.ReasoningActive {
|
||||
delta := gjson.GetBytes(payload, "delta").String()
|
||||
if delta != "" {
|
||||
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
|
||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
|
||||
}
|
||||
}
|
||||
case "response.reasoning_summary_part.done":
|
||||
if state.ReasoningActive {
|
||||
thinkingStop := `{"type":"content_block_stop","index":0}`
|
||||
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
|
||||
state.ReasoningActive = false
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
state.HasToolUse = true
|
||||
tool := &githubCopilotResponsesStreamToolState{
|
||||
Index: state.NextContentIndex,
|
||||
ID: gjson.GetBytes(payload, "item.call_id").String(),
|
||||
Name: gjson.GetBytes(payload, "item.name").String(),
|
||||
}
|
||||
if tool.ID == "" {
|
||||
tool.ID = gjson.GetBytes(payload, "item.id").String()
|
||||
}
|
||||
state.NextContentIndex++
|
||||
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
|
||||
state.OutputIndexToTool[outputIndex] = tool
|
||||
if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
|
||||
state.ItemIDToTool[itemID] = tool
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
case "response.output_item.delta":
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if item.Get("type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
partial := gjson.GetBytes(payload, "delta").String()
|
||||
if partial == "" {
|
||||
partial = item.Get("arguments").String()
|
||||
}
|
||||
if partial == "" {
|
||||
break
|
||||
}
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
case "response.function_call_arguments.delta":
|
||||
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
|
||||
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
|
||||
itemID := gjson.GetBytes(payload, "item_id").String()
|
||||
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
|
||||
tool := resolveTool(itemID, outputIndex)
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
partial := gjson.GetBytes(payload, "delta").String()
|
||||
if partial == "" {
|
||||
break
|
||||
}
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
case "response.output_item.done":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
case "response.completed":
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
if !state.MessageStopSent {
|
||||
stopReason := "end_turn"
|
||||
if state.HasToolUse {
|
||||
stopReason = "tool_use"
|
||||
} else if sr := gjson.GetBytes(payload, "response.stop_reason").String(); sr == "max_tokens" || sr == "stop" {
|
||||
stopReason = sr
|
||||
}
|
||||
inputTokens := gjson.GetBytes(payload, "response.usage.input_tokens").Int()
|
||||
outputTokens := gjson.GetBytes(payload, "response.usage.output_tokens").Int()
|
||||
cachedTokens := gjson.GetBytes(payload, "response.usage.input_tokens_details.cached_tokens").Int()
|
||||
if cachedTokens > 0 && inputTokens >= cachedTokens {
|
||||
inputTokens -= cachedTokens
|
||||
}
|
||||
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", inputTokens)
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
|
||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
state.MessageStopSent = true
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||
func isHTTPSuccess(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -52,3 +55,312 @@ func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected openai-response source to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
|
||||
t.Fatal("expected codex model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if !in.IsArray() {
|
||||
t.Fatalf("input type = %v, want array", in.Type)
|
||||
}
|
||||
raw := in.Raw
|
||||
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
|
||||
t.Fatalf("input = %s, want structured array with all texts", raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "messages").Exists() {
|
||||
t.Fatal("messages should be removed after conversion")
|
||||
}
|
||||
if gjson.GetBytes(got, "system").Exists() {
|
||||
t.Fatal("system should be removed after conversion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":{"foo":"bar"}}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "foo") {
|
||||
t.Fatalf("input = %q, want stringified object", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("name").String() != "sum" {
|
||||
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools len = %d, want 2", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
if tools[0].Get("name").String() != "Bash" {
|
||||
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
|
||||
}
|
||||
if tools[0].Get("description").String() != "Run commands" {
|
||||
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be set from input_schema")
|
||||
}
|
||||
if tools[0].Get("parameters.properties.command").Exists() != true {
|
||||
t.Fatal("expected parameters.properties.command to exist")
|
||||
}
|
||||
if tools[1].Get("name").String() != "Read" {
|
||||
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
|
||||
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
|
||||
}
|
||||
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
|
||||
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function"}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
joinedDelta := strings.Join(delta, "")
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
joinedCompleted := strings.Join(completed, "")
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for X-Initiator detection logic (Problem L) ---
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Last role governs the initiator decision.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Hi"}]},{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last role is assistant)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Use tool"}]},{"type":"function_call","call_id":"c1","name":"Read","arguments":"{}"},{"type":"function_call_output","call_id":"c1","output":"ok"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last item maps to tool role)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for x-github-api-version header (Problem M) ---
|
||||
|
||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
e.applyHeaders(req, "token", nil)
|
||||
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
|
||||
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for vision detection (Problem P) ---
|
||||
|
||||
func TestDetectVisionContent_WithImageURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected vision content to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_WithImageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected image type to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoVision(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
// After Responses API normalization, messages is removed — detection should return false
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content when messages field is absent")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request.
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
460
internal/runtime/executor/kilo_executor.go
Normal file
460
internal/runtime/executor/kilo_executor.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// KiloExecutor handles requests to Kilo API.
|
||||
type KiloExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKiloExecutor creates a new Kilo executor instance.
|
||||
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
|
||||
return &KiloExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *KiloExecutor) Identifier() string { return "kilo" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, _ := kiloCredentials(auth)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kilo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh validates the Kilo token.
|
||||
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("missing auth")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// CountTokens returns the token count for the given request.
|
||||
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
|
||||
}
|
||||
|
||||
// kiloCredentials extracts access token and other info from auth.
|
||||
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Prefer kilocode specific keys, then fall back to generic keys.
|
||||
// Check metadata first, then attributes.
|
||||
if auth.Metadata != nil {
|
||||
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
|
||||
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken == "" && auth.Attributes != nil {
|
||||
if token := auth.Attributes["kilocodeToken"]; token != "" {
|
||||
accessToken = token
|
||||
} else if token := auth.Attributes["access_token"]; token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
}
|
||||
|
||||
if orgID == "" && auth.Attributes != nil {
|
||||
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
|
||||
orgID = org
|
||||
} else if org := auth.Attributes["organization_id"]; org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, orgID
|
||||
}
|
||||
|
||||
// FetchKiloModels fetches models from Kilo API.
|
||||
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to create model fetch request: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
req.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
req.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Warnf("kilo: fetch models canceled: %v", err)
|
||||
} else {
|
||||
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
|
||||
}
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to read models response: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(body, "data")
|
||||
if !result.Exists() {
|
||||
// Try root if data field is missing
|
||||
result = gjson.ParseBytes(body)
|
||||
if !result.IsArray() {
|
||||
log.Debugf("kilo: response body: %s", string(body))
|
||||
log.Warn("kilo: invalid API response format (expected array or data field with array)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicModels []*registry.ModelInfo
|
||||
now := time.Now().Unix()
|
||||
count := 0
|
||||
totalCount := 0
|
||||
|
||||
result.ForEach(func(key, value gjson.Result) bool {
|
||||
totalCount++
|
||||
id := value.Get("id").String()
|
||||
pIdxResult := value.Get("preferredIndex")
|
||||
preferredIndex := pIdxResult.Int()
|
||||
|
||||
// Filter models where preferredIndex > 0 (Kilo-curated models)
|
||||
if preferredIndex <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
|
||||
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
|
||||
if !isFree {
|
||||
// Check pricing as fallback
|
||||
promptPricing := value.Get("pricing.prompt").String()
|
||||
if promptPricing == "0" || promptPricing == "0.0" {
|
||||
isFree = true
|
||||
}
|
||||
}
|
||||
|
||||
if !isFree {
|
||||
log.Debugf("kilo: skipping curated paid model: %s", id)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
|
||||
|
||||
dynamicModels = append(dynamicModels, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
DisplayName: value.Get("name").String(),
|
||||
ContextLength: int(value.Get("context_length").Int()),
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
})
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
|
||||
if count == 0 && totalCount > 0 {
|
||||
log.Warn("kilo: no curated free models found (check API response fields)")
|
||||
}
|
||||
|
||||
staticModels := registry.GetKiloModels()
|
||||
// Always include kilo/auto (first static model)
|
||||
allModels := append(staticModels[:1], dynamicModels...)
|
||||
|
||||
return allModels
|
||||
}
|
||||
@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user