mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-27 22:27:28 +00:00
Compare commits
188 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8b89f34f4 | ||
|
|
1fa094dac6 | ||
|
|
f55754621f | ||
|
|
ac26e7db43 | ||
|
|
10b824fcac | ||
|
|
7dccc7ba2f | ||
|
|
70c90687fd | ||
|
|
8144ffd5c8 | ||
|
|
6b45d311ec | ||
|
|
7386a70724 | ||
|
|
1821bf7051 | ||
|
|
d42b5d4e78 | ||
|
|
1b7447b682 | ||
|
|
40dee4453a | ||
|
|
8902e1cccb | ||
|
|
de5fe71478 | ||
|
|
dcfbec2990 | ||
|
|
c95620f90e | ||
|
|
9613f0b3f9 | ||
|
|
274f29e26b | ||
|
|
c8e79c3787 | ||
|
|
8afef43887 | ||
|
|
c1083cbfc6 | ||
|
|
1e6bc81cfd | ||
|
|
1a149475e0 | ||
|
|
e5166841db | ||
|
|
19c52bcb60 | ||
|
|
bb9b2d1758 | ||
|
|
7fa527193c | ||
|
|
ed0eb51b4d | ||
|
|
0e4f669c8b | ||
|
|
76c064c729 | ||
|
|
d2f652f436 | ||
|
|
6a452a54d5 | ||
|
|
9e5693e74f | ||
|
|
528b1a2307 | ||
|
|
0cc978ec1d | ||
|
|
d312422ab4 | ||
|
|
fee736933b | ||
|
|
09c92aa0b5 | ||
|
|
8c67b3ae64 | ||
|
|
000e4ceb4e | ||
|
|
5c99846ecf | ||
|
|
d475aaba96 | ||
|
|
1dc4ecb1b8 | ||
|
|
1315f710f5 | ||
|
|
96f55570f7 | ||
|
|
0906aeca87 | ||
|
|
7333619f15 | ||
|
|
97c0487add | ||
|
|
2db8df8e38 | ||
|
|
a576088d5f | ||
|
|
66ff916838 | ||
|
|
7b0453074e | ||
|
|
a000eb523d | ||
|
|
18a4fedc7f | ||
|
|
5d6cdccda0 | ||
|
|
1b7f4ac3e1 | ||
|
|
afc1a5b814 | ||
|
|
7ed38db54f | ||
|
|
28c10f4e69 | ||
|
|
6e12441a3b | ||
|
|
65c439c18d | ||
|
|
0ed2d16596 | ||
|
|
db335ac616 | ||
|
|
f3c59165d7 | ||
|
|
e6690cb447 | ||
|
|
35907416b8 | ||
|
|
e8bb350467 | ||
|
|
5331d51f27 | ||
|
|
755ca75879 | ||
|
|
2398ebad55 | ||
|
|
c1bf298216 | ||
|
|
e005208d76 | ||
|
|
d1df70d02f | ||
|
|
f81acd0760 | ||
|
|
636da4c932 | ||
|
|
cccb77b552 | ||
|
|
2bd646ad70 | ||
|
|
52c1fa025e | ||
|
|
680105f84d | ||
|
|
f7069e9548 | ||
|
|
7275e99b41 | ||
|
|
c28b65f849 | ||
|
|
793840cdb4 | ||
|
|
8f421de532 | ||
|
|
be2dd60ee7 | ||
|
|
ea3e0b713e | ||
|
|
8179d5a8a4 | ||
|
|
6fa7abe434 | ||
|
|
5135c22cd6 | ||
|
|
1e27990561 | ||
|
|
e1e9fc43c1 | ||
|
|
b2921518ac | ||
|
|
dd64adbeeb | ||
|
|
616d41c06a | ||
|
|
e0e337aeb9 | ||
|
|
d52839fced | ||
|
|
4022e69651 | ||
|
|
56073ded69 | ||
|
|
9738a53f49 | ||
|
|
be3f8dbf7e | ||
|
|
9c6c3612a8 | ||
|
|
19e1a4447a | ||
|
|
7c2ad4cda2 | ||
|
|
fb95813fbf | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
cef2aeeb08 | ||
|
|
bcd1e8cc34 | ||
|
|
198b3f4a40 | ||
|
|
9fee7f488e | ||
|
|
1b46d39b8b | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
152c310bb7 | ||
|
|
f6bbca35ab | ||
|
|
c8cee6a209 | ||
|
|
b5701f416b | ||
|
|
4b1a404fcb | ||
|
|
b93cce5412 | ||
|
|
c6cb24039d | ||
|
|
5382408489 | ||
|
|
67669196ed | ||
|
|
58fd9bf964 | ||
|
|
7b3dfc67bc | ||
|
|
cdd24052d3 | ||
|
|
5da0decef6 | ||
|
|
733fd8edab | ||
|
|
af27f2b8bc | ||
|
|
2e1925d762 | ||
|
|
77254bd074 | ||
|
|
5b6342e6ac | ||
|
|
3960c93d51 | ||
|
|
339a81b650 | ||
|
|
560c020477 | ||
|
|
aec65e3be3 | ||
|
|
f44f0702f8 | ||
|
|
b76b79068f | ||
|
|
34c8ccb961 | ||
|
|
d08e164af3 | ||
|
|
8178efaeda | ||
|
|
86d5db472a | ||
|
|
020d36f6e8 | ||
|
|
1db23979e8 | ||
|
|
c3d5dbe96f | ||
|
|
5484489406 | ||
|
|
0ac52da460 | ||
|
|
817cebb321 | ||
|
|
683f3709d6 | ||
|
|
dbd42a42b2 | ||
|
|
ec24baf757 | ||
|
|
dea3e74d35 | ||
|
|
a6c3042e34 | ||
|
|
861537c9bd | ||
|
|
8c92cb0883 | ||
|
|
89d7be9525 | ||
|
|
2b79d7f22f | ||
|
|
2bb686f594 | ||
|
|
163fe287ce | ||
|
|
70988d387b | ||
|
|
52058a1659 | ||
|
|
df5595a0c9 | ||
|
|
ddaa9d2436 | ||
|
|
7b7b258c38 | ||
|
|
a00f774f5a | ||
|
|
9daf1ba8b5 | ||
|
|
76f2359637 | ||
|
|
dcb1c9be8a | ||
|
|
a24f4ace78 | ||
|
|
c631df8c3b | ||
|
|
54c3eb1b1e | ||
|
|
bb28cd26ad | ||
|
|
046865461e | ||
|
|
cf74ed2f0c | ||
|
|
c3762328a5 | ||
|
|
e333fbea3d | ||
|
|
efbe36d1d4 | ||
|
|
8553cfa40e | ||
|
|
30d5c95b26 | ||
|
|
d1e3195e6f |
14
.github/workflows/docker-image.yml
vendored
14
.github/workflows/docker-image.yml
vendored
@@ -16,6 +16,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -25,7 +29,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push (amd64)
|
- name: Build and push (amd64)
|
||||||
@@ -47,6 +51,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -56,7 +64,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push (arm64)
|
- name: Build and push (arm64)
|
||||||
@@ -90,7 +98,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Create and push multi-arch manifests
|
- name: Create and push multi-arch manifests
|
||||||
|
|||||||
4
.github/workflows/pr-test-build.yml
vendored
4
.github/workflows/pr-test-build.yml
vendored
@@ -12,6 +12,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
9
.github/workflows/release.yaml
vendored
9
.github/workflows/release.yaml
vendored
@@ -16,6 +16,10 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- run: git fetch --force --tags
|
- run: git fetch --force --tags
|
||||||
- uses: actions/setup-go@v4
|
- uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
@@ -23,15 +27,14 @@ jobs:
|
|||||||
cache: true
|
cache: true
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
VERSION=$(git describe --tags --always --dirty)
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo "VERSION=${VERSION}" >> $GITHUB_ENV
|
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- uses: goreleaser/goreleaser-action@v4
|
- uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
distribution: goreleaser
|
distribution: goreleaser
|
||||||
version: latest
|
version: latest
|
||||||
args: release --clean
|
args: release --clean --skip=validate
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
VERSION: ${{ env.VERSION }}
|
VERSION: ${{ env.VERSION }}
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
|||||||
# Binaries
|
# Binaries
|
||||||
cli-proxy-api
|
cli-proxy-api
|
||||||
cliproxy
|
cliproxy
|
||||||
|
/server
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
builds:
|
builds:
|
||||||
- id: "cli-proxy-api-plus"
|
- id: "cli-proxy-api-plus"
|
||||||
env:
|
env:
|
||||||
@@ -6,6 +8,7 @@ builds:
|
|||||||
- linux
|
- linux
|
||||||
- windows
|
- windows
|
||||||
- darwin
|
- darwin
|
||||||
|
- freebsd
|
||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
# CLIProxyAPI Plus
|
# CLIProxyAPI Plus
|
||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文 | [日本語](README_JA.md)
|
||||||
|
|
||||||
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
||||||
|
|
||||||
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
||||||
|
|
||||||
该 Plus 版本的主线功能与主线功能强制同步。
|
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||||
@@ -16,4 +14,4 @@
|
|||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||||
|
|||||||
187
README_JA.md
Normal file
187
README_JA.md
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
# CLI Proxy API
|
||||||
|
|
||||||
|
[English](README.md) | [中文](README_CN.md) | 日本語
|
||||||
|
|
||||||
|
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
|
||||||
|
|
||||||
|
OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポートしています。
|
||||||
|
|
||||||
|
ローカルまたはマルチアカウントのCLIアクセスを、OpenAI(Responses含む)/Gemini/Claude互換のクライアントやSDKで利用できます。
|
||||||
|
|
||||||
|
## スポンサー
|
||||||
|
|
||||||
|
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||||
|
|
||||||
|
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
|
||||||
|
|
||||||
|
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
|
||||||
|
|
||||||
|
GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||||
|
<td>PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||||
|
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
|
||||||
|
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
## 概要
|
||||||
|
|
||||||
|
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
|
||||||
|
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
|
||||||
|
- OAuthログインによるClaude Codeサポート
|
||||||
|
- OAuthログインによるQwen Codeサポート
|
||||||
|
- OAuthログインによるiFlowサポート
|
||||||
|
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
|
||||||
|
- ストリーミングおよび非ストリーミングレスポンス
|
||||||
|
- 関数呼び出し/ツールのサポート
|
||||||
|
- マルチモーダル入力サポート(テキストと画像)
|
||||||
|
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||||
|
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
||||||
|
- Generative Language APIキーのサポート
|
||||||
|
- AI Studioビルドのマルチアカウント負荷分散
|
||||||
|
- Gemini CLIのマルチアカウント負荷分散
|
||||||
|
- Claude Codeのマルチアカウント負荷分散
|
||||||
|
- Qwen Codeのマルチアカウント負荷分散
|
||||||
|
- iFlowのマルチアカウント負荷分散
|
||||||
|
- OpenAI Codexのマルチアカウント負荷分散
|
||||||
|
- 設定によるOpenAI互換アップストリームプロバイダー(例:OpenRouter)
|
||||||
|
- プロキシ埋め込み用の再利用可能なGo SDK(`docs/sdk-usage.md`を参照)
|
||||||
|
|
||||||
|
## はじめに
|
||||||
|
|
||||||
|
CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/)
|
||||||
|
|
||||||
|
## 管理API
|
||||||
|
|
||||||
|
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
|
||||||
|
|
||||||
|
## Amp CLIサポート
|
||||||
|
|
||||||
|
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
|
||||||
|
|
||||||
|
- Ampの APIパターン用のプロバイダールートエイリアス(`/api/provider/{provider}/v1...`)
|
||||||
|
- OAuth認証およびアカウント機能用の管理プロキシ
|
||||||
|
- 自動ルーティングによるスマートモデルフォールバック
|
||||||
|
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
|
||||||
|
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
|
||||||
|
|
||||||
|
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||||
|
|
||||||
|
## SDKドキュメント
|
||||||
|
|
||||||
|
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
|
||||||
|
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
|
||||||
|
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
|
||||||
|
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
|
||||||
|
- カスタムプロバイダーの例:`examples/custom-provider`
|
||||||
|
|
||||||
|
## コントリビューション
|
||||||
|
|
||||||
|
コントリビューションを歓迎します!お気軽にPull Requestを送ってください。
|
||||||
|
|
||||||
|
1. リポジトリをフォーク
|
||||||
|
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`)
|
||||||
|
3. 変更をコミット(`git commit -m 'Add some amazing feature'`)
|
||||||
|
4. ブランチにプッシュ(`git push origin feature/amazing-feature`)
|
||||||
|
5. Pull Requestを作成
|
||||||
|
|
||||||
|
## 関連プロジェクト
|
||||||
|
|
||||||
|
CLIProxyAPIをベースにした以下のプロジェクトがあります:
|
||||||
|
|
||||||
|
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
||||||
|
|
||||||
|
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
|
||||||
|
|
||||||
|
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
||||||
|
|
||||||
|
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
|
||||||
|
|
||||||
|
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
||||||
|
|
||||||
|
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデル(Gemini、Codex、Antigravity)を即座に切り替えるCLIラッパー - APIキー不要
|
||||||
|
|
||||||
|
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
||||||
|
|
||||||
|
CLIProxyAPI管理用のmacOSネイティブGUI:OAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
|
||||||
|
|
||||||
|
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||||
|
|
||||||
|
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
|
||||||
|
|
||||||
|
### [CodMate](https://github.com/loocor/CodMate)
|
||||||
|
|
||||||
|
CLI AIセッション(Codex、Claude Code、Gemini CLI)を管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
|
||||||
|
|
||||||
|
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||||
|
|
||||||
|
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
|
||||||
|
|
||||||
|
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||||
|
|
||||||
|
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
|
||||||
|
|
||||||
|
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
||||||
|
|
||||||
|
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
|
||||||
|
|
||||||
|
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||||
|
|
||||||
|
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
|
||||||
|
|
||||||
|
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||||
|
|
||||||
|
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替え(Main / Plus)、自動ダウンロードおよび自動更新に対応
|
||||||
|
|
||||||
|
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||||
|
|
||||||
|
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
|
||||||
|
|
||||||
|
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
||||||
|
|
||||||
|
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
|
||||||
|
|
||||||
|
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||||
|
|
||||||
|
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
|
||||||
|
|
||||||
|
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
|
||||||
|
|
||||||
|
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LAN(ローカルエリアネットワーク)を介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||||
|
|
||||||
|
## その他の選択肢
|
||||||
|
|
||||||
|
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです:
|
||||||
|
|
||||||
|
### [9Router](https://github.com/decolua/9router)
|
||||||
|
|
||||||
|
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換(OpenAI/Claude/Gemini/Ollama)、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツール(Cursor、Claude Code、Cline、RooCode)のサポートをゼロから構築 - APIキー不要
|
||||||
|
|
||||||
|
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
||||||
|
|
||||||
|
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
|
||||||
|
|
||||||
|
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
||||||
|
|
||||||
|
## ライセンス
|
||||||
|
|
||||||
|
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。
|
||||||
BIN
assets/bmoplus.png
Normal file
BIN
assets/bmoplus.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
275
cmd/fetch_antigravity_models/main.go
Normal file
275
cmd/fetch_antigravity_models/main.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
// Command fetch_antigravity_models connects to the Antigravity API using the
|
||||||
|
// stored auth credentials and saves the dynamically fetched model list to a
|
||||||
|
// JSON file for inspection or offline use.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// go run ./cmd/fetch_antigravity_models [flags]
|
||||||
|
//
|
||||||
|
// Flags:
|
||||||
|
//
|
||||||
|
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
|
||||||
|
// --output <path> Output JSON file path (default: "antigravity_models.json")
|
||||||
|
// --pretty Pretty-print the output JSON (default: true)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
|
||||||
|
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
|
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||||
|
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logging.SetupBaseLogger()
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelOutput wraps the fetched model list with fetch metadata.
|
||||||
|
type modelOutput struct {
|
||||||
|
Models []modelEntry `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEntry contains only the fields we want to keep for static model definitions.
|
||||||
|
type modelEntry struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
ContextLength int `json:"context_length,omitempty"`
|
||||||
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var authsDir string
|
||||||
|
var outputPath string
|
||||||
|
var pretty bool
|
||||||
|
|
||||||
|
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
|
||||||
|
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
|
||||||
|
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Resolve relative paths against the working directory.
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(authsDir) {
|
||||||
|
authsDir = filepath.Join(wd, authsDir)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(outputPath) {
|
||||||
|
outputPath = filepath.Join(wd, outputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Scanning auth files in: %s\n", authsDir)
|
||||||
|
|
||||||
|
// Load all auth records from the directory.
|
||||||
|
fileStore := sdkauth.NewFileTokenStore()
|
||||||
|
fileStore.SetBaseDir(authsDir)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
auths, err := fileStore.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if len(auths) == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first enabled antigravity auth.
|
||||||
|
var chosen *coreauth.Auth
|
||||||
|
for _, a := range auths {
|
||||||
|
if a == nil || a.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
|
||||||
|
chosen = a
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if chosen == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
|
||||||
|
|
||||||
|
// Fetch models from the upstream Antigravity API.
|
||||||
|
fmt.Println("Fetching Antigravity model list from upstream...")
|
||||||
|
|
||||||
|
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
models := fetchModels(fetchCtx, chosen)
|
||||||
|
if len(models) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Fetched %d models.\n", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the output payload.
|
||||||
|
out := modelOutput{
|
||||||
|
Models: models,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal to JSON.
|
||||||
|
var raw []byte
|
||||||
|
if pretty {
|
||||||
|
raw, err = json.MarshalIndent(out, "", " ")
|
||||||
|
} else {
|
||||||
|
raw, err = json.Marshal(out)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Model list saved to: %s\n", outputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
||||||
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
|
||||||
|
|
||||||
|
for _, baseURL := range baseURLs {
|
||||||
|
modelsURL := baseURL + antigravityModelsPath
|
||||||
|
|
||||||
|
var payload []byte
|
||||||
|
if auth != nil && auth.Metadata != nil {
|
||||||
|
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||||
|
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
payload = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
|
||||||
|
if errReq != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
httpReq.Close = true
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
|
||||||
|
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||||
|
httpClient.Transport = transport
|
||||||
|
}
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
httpResp.Body.Close()
|
||||||
|
if errRead != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := gjson.GetBytes(bodyBytes, "models")
|
||||||
|
if !result.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []modelEntry
|
||||||
|
|
||||||
|
for originalName, modelData := range result.Map() {
|
||||||
|
modelID := strings.TrimSpace(originalName)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Skip internal/experimental models
|
||||||
|
switch modelID {
|
||||||
|
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
displayName := modelData.Get("displayName").String()
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = modelID
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := modelEntry{
|
||||||
|
ID: modelID,
|
||||||
|
Object: "model",
|
||||||
|
OwnedBy: "antigravity",
|
||||||
|
Type: "antigravity",
|
||||||
|
DisplayName: displayName,
|
||||||
|
Name: modelID,
|
||||||
|
Description: displayName,
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||||
|
entry.ContextLength = int(maxTok)
|
||||||
|
}
|
||||||
|
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||||
|
entry.MaxCompletionTokens = int(maxOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
models = append(models, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func metaStringValue(m map[string]interface{}, key string) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v, ok := m[key]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
20
cmd/mcpdebug/main.go
Normal file
20
cmd/mcpdebug/main.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Encode MCP result with empty execId
|
||||||
|
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
|
||||||
|
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
|
||||||
|
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
|
||||||
|
|
||||||
|
// Write to file for analysis
|
||||||
|
os.WriteFile("mcp_result.bin", resultBytes)
|
||||||
|
fmt.Println("Wrote mcp_result.bin")
|
||||||
|
}
|
||||||
32
cmd/protocheck/main.go
Normal file
32
cmd/protocheck/main.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||||
|
|
||||||
|
// Try different field names
|
||||||
|
names := []string{
|
||||||
|
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||||
|
"shell_result", "shellResult",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range names {
|
||||||
|
fd := ecm.Descriptor().Fields().ByName(name)
|
||||||
|
if fd != nil {
|
||||||
|
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List all fields
|
||||||
|
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||||
|
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||||
|
f := ecm.Descriptor().Fields().Get(i)
|
||||||
|
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||||
@@ -78,10 +79,13 @@ func main() {
|
|||||||
var kiloLogin bool
|
var kiloLogin bool
|
||||||
var iflowLogin bool
|
var iflowLogin bool
|
||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
|
var gitlabLogin bool
|
||||||
|
var gitlabTokenLogin bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
var oauthCallbackPort int
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
var kimiLogin bool
|
var kimiLogin bool
|
||||||
|
var cursorLogin bool
|
||||||
var kiroLogin bool
|
var kiroLogin bool
|
||||||
var kiroGoogleLogin bool
|
var kiroGoogleLogin bool
|
||||||
var kiroAWSLogin bool
|
var kiroAWSLogin bool
|
||||||
@@ -92,6 +96,7 @@ func main() {
|
|||||||
var kiroIDCRegion string
|
var kiroIDCRegion string
|
||||||
var kiroIDCFlow string
|
var kiroIDCFlow string
|
||||||
var githubCopilotLogin bool
|
var githubCopilotLogin bool
|
||||||
|
var codeBuddyLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
var configPath string
|
var configPath string
|
||||||
@@ -100,6 +105,7 @@ func main() {
|
|||||||
var standalone bool
|
var standalone bool
|
||||||
var noIncognito bool
|
var noIncognito bool
|
||||||
var useIncognito bool
|
var useIncognito bool
|
||||||
|
var localModel bool
|
||||||
|
|
||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
@@ -110,12 +116,15 @@ func main() {
|
|||||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
|
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
|
||||||
|
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||||
|
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
|
||||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
@@ -126,12 +135,14 @@ func main() {
|
|||||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||||
|
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
flag.StringVar(&password, "password", "", "")
|
flag.StringVar(&password, "password", "", "")
|
||||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||||
|
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||||
|
|
||||||
flag.CommandLine.Usage = func() {
|
flag.CommandLine.Usage = func() {
|
||||||
out := flag.CommandLine.Output()
|
out := flag.CommandLine.Output()
|
||||||
@@ -509,6 +520,9 @@ func main() {
|
|||||||
} else if githubCopilotLogin {
|
} else if githubCopilotLogin {
|
||||||
// Handle GitHub Copilot login
|
// Handle GitHub Copilot login
|
||||||
cmd.DoGitHubCopilotLogin(cfg, options)
|
cmd.DoGitHubCopilotLogin(cfg, options)
|
||||||
|
} else if codeBuddyLogin {
|
||||||
|
// Handle CodeBuddy login
|
||||||
|
cmd.DoCodeBuddyLogin(cfg, options)
|
||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
@@ -526,8 +540,14 @@ func main() {
|
|||||||
cmd.DoIFlowLogin(cfg, options)
|
cmd.DoIFlowLogin(cfg, options)
|
||||||
} else if iflowCookie {
|
} else if iflowCookie {
|
||||||
cmd.DoIFlowCookieAuth(cfg, options)
|
cmd.DoIFlowCookieAuth(cfg, options)
|
||||||
|
} else if gitlabLogin {
|
||||||
|
cmd.DoGitLabLogin(cfg, options)
|
||||||
|
} else if gitlabTokenLogin {
|
||||||
|
cmd.DoGitLabTokenLogin(cfg, options)
|
||||||
} else if kimiLogin {
|
} else if kimiLogin {
|
||||||
cmd.DoKimiLogin(cfg, options)
|
cmd.DoKimiLogin(cfg, options)
|
||||||
|
} else if cursorLogin {
|
||||||
|
cmd.DoCursorLogin(cfg, options)
|
||||||
} else if kiroLogin {
|
} else if kiroLogin {
|
||||||
// For Kiro auth, default to incognito mode for multi-account support
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
// Users can explicitly override with --no-incognito
|
// Users can explicitly override with --no-incognito
|
||||||
@@ -569,10 +589,16 @@ func main() {
|
|||||||
cmd.WaitForCloudDeploy()
|
cmd.WaitForCloudDeploy()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if localModel && (!tuiMode || standalone) {
|
||||||
|
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
|
||||||
|
}
|
||||||
if tuiMode {
|
if tuiMode {
|
||||||
if standalone {
|
if standalone {
|
||||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
if !localModel {
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
}
|
||||||
hook := tui.NewLogHook(2000)
|
hook := tui.NewLogHook(2000)
|
||||||
hook.SetFormatter(&logging.LogFormatter{})
|
hook.SetFormatter(&logging.LogFormatter{})
|
||||||
log.AddHook(hook)
|
log.AddHook(hook)
|
||||||
@@ -643,15 +669,18 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
if !localModel {
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.AuthDir != "" {
|
if cfg.AuthDir != "" {
|
||||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||||
defer kiro.StopGlobalRefreshManager()
|
defer kiro.StopGlobalRefreshManager()
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ remote-management:
|
|||||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||||
disable-control-panel: false
|
disable-control-panel: false
|
||||||
|
|
||||||
|
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
|
||||||
|
# When enabled, the panel is only downloaded on first access if missing, and never auto-updated afterward.
|
||||||
|
# disable-auto-update-panel: false
|
||||||
|
|
||||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||||
|
|
||||||
@@ -68,7 +72,8 @@ error-logs-max-files: 10
|
|||||||
usage-statistics-enabled: false
|
usage-statistics-enabled: false
|
||||||
|
|
||||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||||
proxy-url: ''
|
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
|
||||||
|
proxy-url: ""
|
||||||
|
|
||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
force-model-prefix: false
|
force-model-prefix: false
|
||||||
@@ -115,6 +120,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080"
|
# proxy-url: "socks5://proxy.example.com:1080"
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gemini-2.5-flash" # upstream model name
|
# - name: "gemini-2.5-flash" # upstream model name
|
||||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||||
@@ -133,6 +139,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gpt-5-codex" # upstream model name
|
# - name: "gpt-5-codex" # upstream model name
|
||||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||||
@@ -151,6 +158,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||||
@@ -171,12 +179,27 @@ nonstream-keepalive-interval: 0
|
|||||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||||
|
|
||||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||||
# These are used as fallbacks when the client does not send its own headers.
|
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||||
|
# when the client omits them, while OS/arch remain runtime-derived. When
|
||||||
|
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
|
||||||
|
# while user-agent/package-version/runtime-version seed a software fingerprint that can
|
||||||
|
# still upgrade to newer official Claude client versions.
|
||||||
# claude-header-defaults:
|
# claude-header-defaults:
|
||||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||||
# package-version: "0.74.0"
|
# package-version: "0.74.0"
|
||||||
# runtime-version: "v24.3.0"
|
# runtime-version: "v24.3.0"
|
||||||
|
# os: "MacOS"
|
||||||
|
# arch: "arm64"
|
||||||
# timeout: "600"
|
# timeout: "600"
|
||||||
|
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
|
||||||
|
|
||||||
|
# Default headers for Codex OAuth model requests.
|
||||||
|
# These are used only for file-backed/OAuth Codex requests when the client
|
||||||
|
# does not send the header. `user-agent` applies to HTTP and websocket requests;
|
||||||
|
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
|
||||||
|
# codex-header-defaults:
|
||||||
|
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
|
||||||
|
# beta-features: "multi_agent"
|
||||||
|
|
||||||
# Kiro (AWS CodeWhisperer) configuration
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
# Note: Kiro API currently only operates in us-east-1 region
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
@@ -215,10 +238,13 @@ nonstream-keepalive-interval: 0
|
|||||||
# api-key-entries:
|
# api-key-entries:
|
||||||
# - api-key: "sk-or-v1-...b780"
|
# - api-key: "sk-or-v1-...b780"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||||
# models: # The models supported by the provider.
|
# models: # The models supported by the provider.
|
||||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||||
# alias: "kimi-k2" # The alias used in the API.
|
# alias: "kimi-k2" # The alias used in the API.
|
||||||
|
# thinking: # optional: omit to default to levels ["low","medium","high"]
|
||||||
|
# levels: ["low", "medium", "high"]
|
||||||
# # You may repeat the same alias to build an internal model pool.
|
# # You may repeat the same alias to build an internal model pool.
|
||||||
# # The client still sees only one alias in the model list.
|
# # The client still sees only one alias in the model list.
|
||||||
# # Requests to that alias will round-robin across the upstream names below,
|
# # Requests to that alias will round-robin across the upstream names below,
|
||||||
@@ -231,12 +257,13 @@ nonstream-keepalive-interval: 0
|
|||||||
# - name: "kimi-k2.5"
|
# - name: "kimi-k2.5"
|
||||||
# alias: "claude-opus-4.66"
|
# alias: "claude-opus-4.66"
|
||||||
|
|
||||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
|
||||||
# vertex-api-key:
|
# vertex-api-key:
|
||||||
# - api-key: "vk-123..." # x-goog-api-key header
|
# - api-key: "vk-123..." # x-goog-api-key header
|
||||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# models: # optional: map aliases to upstream model names
|
# models: # optional: map aliases to upstream model names
|
||||||
|
|||||||
115
docs/gitlab-duo.md
Normal file
115
docs/gitlab-duo.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# GitLab Duo guide
|
||||||
|
|
||||||
|
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
|
||||||
|
|
||||||
|
It supports:
|
||||||
|
|
||||||
|
- OAuth login
|
||||||
|
- personal access token login
|
||||||
|
- automatic refresh of GitLab `direct_access` metadata
|
||||||
|
- dynamic model discovery from GitLab metadata
|
||||||
|
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
|
||||||
|
- Claude-compatible and OpenAI-compatible downstream APIs
|
||||||
|
|
||||||
|
## What this means
|
||||||
|
|
||||||
|
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
|
||||||
|
|
||||||
|
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
|
||||||
|
|
||||||
|
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
|
||||||
|
|
||||||
|
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
|
||||||
|
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
|
||||||
|
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
|
||||||
|
|
||||||
|
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
|
||||||
|
|
||||||
|
- a stable alias: `gitlab-duo`
|
||||||
|
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
|
||||||
|
|
||||||
|
## Login
|
||||||
|
|
||||||
|
OAuth login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-login
|
||||||
|
```
|
||||||
|
|
||||||
|
PAT login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-token-login
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also provide inputs through environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GITLAB_BASE_URL=https://gitlab.com
|
||||||
|
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||||
|
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||||
|
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- OAuth requires a GitLab OAuth application.
|
||||||
|
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
|
||||||
|
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
|
||||||
|
|
||||||
|
## Using the models
|
||||||
|
|
||||||
|
After login, start CLIProxyAPI normally and point your client at the local proxy.
|
||||||
|
|
||||||
|
You can select:
|
||||||
|
|
||||||
|
- `gitlab-duo` to use the current Duo-managed model for that account
|
||||||
|
- the discovered provider model name if you want to pin it explicitly
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "gitlab-duo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
|
||||||
|
|
||||||
|
## How model freshness works
|
||||||
|
|
||||||
|
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
|
||||||
|
|
||||||
|
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
|
||||||
|
|
||||||
|
This matches GitLab's current public contract better than hardcoding model names.
|
||||||
|
|
||||||
|
## Current scope
|
||||||
|
|
||||||
|
The GitLab Duo provider now has:
|
||||||
|
|
||||||
|
- OAuth and PAT auth flows
|
||||||
|
- runtime refresh of Duo gateway credentials
|
||||||
|
- native Anthropic gateway routing
|
||||||
|
- native OpenAI/Codex gateway routing
|
||||||
|
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
|
||||||
|
|
||||||
|
Still out of scope today:
|
||||||
|
|
||||||
|
- websocket or session-specific parity beyond the current HTTP APIs
|
||||||
|
- GitLab-specific IDE features that are not exposed through the public gateway contract
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||||
|
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||||
|
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||||
115
docs/gitlab-duo_CN.md
Normal file
115
docs/gitlab-duo_CN.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# GitLab Duo 使用说明
|
||||||
|
|
||||||
|
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
|
||||||
|
|
||||||
|
当前支持:
|
||||||
|
|
||||||
|
- OAuth 登录
|
||||||
|
- personal access token 登录
|
||||||
|
- 自动刷新 GitLab `direct_access` 元数据
|
||||||
|
- 根据 GitLab 返回的元数据动态发现模型
|
||||||
|
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
|
||||||
|
- Claude 兼容与 OpenAI 兼容下游 API
|
||||||
|
|
||||||
|
## 这意味着什么
|
||||||
|
|
||||||
|
如果 GitLab Duo 返回的是 Anthropic 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
|
||||||
|
|
||||||
|
如果 GitLab Duo 返回的是 OpenAI 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
|
||||||
|
|
||||||
|
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider:
|
||||||
|
|
||||||
|
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
|
||||||
|
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
|
||||||
|
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
|
||||||
|
|
||||||
|
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
|
||||||
|
|
||||||
|
- 一个稳定别名:`gitlab-duo`
|
||||||
|
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5` 或 `gpt-5-codex`
|
||||||
|
|
||||||
|
## 登录
|
||||||
|
|
||||||
|
OAuth 登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-login
|
||||||
|
```
|
||||||
|
|
||||||
|
PAT 登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-token-login
|
||||||
|
```
|
||||||
|
|
||||||
|
也可以通过环境变量提供输入:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GITLAB_BASE_URL=https://gitlab.com
|
||||||
|
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||||
|
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||||
|
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
|
||||||
|
- OAuth 方式需要一个 GitLab OAuth application。
|
||||||
|
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上,`api` scope 是最稳妥的基线。
|
||||||
|
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
|
||||||
|
|
||||||
|
## 如何使用模型
|
||||||
|
|
||||||
|
登录完成后,正常启动 CLIProxyAPI,并让客户端连接到本地代理。
|
||||||
|
|
||||||
|
你可以选择:
|
||||||
|
|
||||||
|
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
|
||||||
|
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "gitlab-duo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
如果该 GitLab 账号当前绑定的是 Anthropic 模型,Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型,OpenAI 兼容客户端可以通过 `/v1/chat/completions` 或 `/v1/responses` 使用它。
|
||||||
|
|
||||||
|
## 模型如何保持最新
|
||||||
|
|
||||||
|
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
|
||||||
|
|
||||||
|
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
|
||||||
|
|
||||||
|
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
|
||||||
|
|
||||||
|
## 当前覆盖范围
|
||||||
|
|
||||||
|
GitLab Duo Provider 目前已经具备:
|
||||||
|
|
||||||
|
- OAuth 和 PAT 登录流程
|
||||||
|
- Duo gateway 凭据的运行时刷新
|
||||||
|
- Anthropic gateway 原生路由
|
||||||
|
- OpenAI/Codex gateway 原生路由
|
||||||
|
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
|
||||||
|
|
||||||
|
当前仍未覆盖:
|
||||||
|
|
||||||
|
- websocket 或 session 级别的完全对齐
|
||||||
|
- GitLab 公开 gateway 契约之外的 IDE 专有能力
|
||||||
|
|
||||||
|
## 参考资料
|
||||||
|
|
||||||
|
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||||
|
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||||
|
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||||
@@ -52,11 +52,11 @@ func init() {
|
|||||||
sdktr.Register(fOpenAI, fMyProv,
|
sdktr.Register(fOpenAI, fMyProv,
|
||||||
func(model string, raw []byte, stream bool) []byte { return raw },
|
func(model string, raw []byte, stream bool) []byte { return raw },
|
||||||
sdktr.ResponseTransform{
|
sdktr.ResponseTransform{
|
||||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
|
||||||
return []string{string(raw)}
|
return [][]byte{raw}
|
||||||
},
|
},
|
||||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
|
||||||
return string(raw)
|
return raw
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
278
gitlab-duo-codex-parity-plan.md
Normal file
278
gitlab-duo-codex-parity-plan.md
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
# Plan: GitLab Duo Codex Parity
|
||||||
|
|
||||||
|
**Generated**: 2026-03-10
|
||||||
|
**Estimated Complexity**: High
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Bring GitLab Duo support from the current "auth + basic executor" stage to the same practical level as `codex` inside `CLIProxyAPI`: a user logs in once, points external clients such as Claude Code at `CLIProxyAPI`, selects GitLab Duo-backed models, and gets stable streaming, multi-turn behavior, tool calling compatibility, and predictable model routing without manual provider-specific workarounds.
|
||||||
|
|
||||||
|
The core architectural shift is to stop treating GitLab Duo as only two REST wrappers (`/api/v4/chat/completions` and `/api/v4/code_suggestions/completions`) and instead use GitLab's `direct_access` contract as the primary runtime entrypoint wherever possible. Official GitLab docs confirm that `direct_access` returns AI gateway connection details, headers, token, and expiry; that contract is the closest path to codex-like provider behavior.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
- Official GitLab Duo API references confirmed during implementation:
|
||||||
|
- `POST /api/v4/code_suggestions/direct_access`
|
||||||
|
- `POST /api/v4/code_suggestions/completions`
|
||||||
|
- `POST /api/v4/chat/completions`
|
||||||
|
- Access to at least one real GitLab Duo account for manual verification.
|
||||||
|
- One downstream client target for acceptance testing:
|
||||||
|
- Claude Code against Claude-compatible endpoint
|
||||||
|
- OpenAI-compatible client against `/v1/chat/completions` and `/v1/responses`
|
||||||
|
- Existing PR branch as starting point:
|
||||||
|
- `feat/gitlab-duo-auth`
|
||||||
|
- PR [#2028](https://github.com/router-for-me/CLIProxyAPI/pull/2028)
|
||||||
|
|
||||||
|
## Definition Of Done
|
||||||
|
- GitLab Duo models can be used via `CLIProxyAPI` from the same client surfaces that already work for `codex`.
|
||||||
|
- Upstream streaming is real passthrough or faithful chunked forwarding, not synthetic whole-response replay.
|
||||||
|
- Tool/function calling survives translation layers without dropping fields or corrupting names.
|
||||||
|
- Multi-turn and session semantics are stable across `chat/completions`, `responses`, and Claude-compatible routes.
|
||||||
|
- Model exposure stays current from GitLab metadata or gateway discovery without hardcoded stale model tables.
|
||||||
|
- `go test ./...` stays green and at least one real manual end-to-end client flow is documented.
|
||||||
|
|
||||||
|
## Sprint 1: Contract And Gap Closure
|
||||||
|
**Goal**: Replace assumptions with a hard compatibility contract between current `codex` behavior and what GitLab Duo can actually support.
|
||||||
|
|
||||||
|
**Demo/Validation**:
|
||||||
|
- Written matrix showing `codex` features vs current GitLab Duo behavior.
|
||||||
|
- One checked-in developer note or test fixture for real GitLab Duo payload examples.
|
||||||
|
|
||||||
|
### Task 1.1: Freeze Codex Parity Checklist
|
||||||
|
- **Location**: [internal/runtime/executor/codex_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_executor.go), [internal/runtime/executor/codex_websockets_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_websockets_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go)
|
||||||
|
- **Description**: Produce a concrete feature matrix for `codex`: HTTP execute, SSE execute, `/v1/responses`, websocket downstream path, tool calling, request IDs, session close semantics, and model registration behavior.
|
||||||
|
- **Dependencies**: None
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- A checklist exists in repo docs or issue notes.
|
||||||
|
- Each capability is marked `required`, `optional`, or `not possible` for GitLab Duo.
|
||||||
|
- **Validation**:
|
||||||
|
- Review against current `codex` code paths.
|
||||||
|
|
||||||
|
### Task 1.2: Lock GitLab Duo Runtime Contract
|
||||||
|
- **Location**: [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||||
|
- **Description**: Validate the exact upstream contract we can rely on:
|
||||||
|
- `direct_access` fields and refresh cadence
|
||||||
|
- whether AI gateway path is usable directly
|
||||||
|
- when `chat/completions` is available vs when fallback is required
|
||||||
|
- what streaming shape is returned by `code_suggestions/completions?stream=true`
|
||||||
|
- **Dependencies**: Task 1.1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- GitLab transport decision is explicit: `gateway-first`, `REST-first`, or `hybrid`.
|
||||||
|
- Unknown areas are isolated behind feature flags, not spread across executor logic.
|
||||||
|
- **Validation**:
|
||||||
|
- Official docs + captured real responses from a Duo account.
|
||||||
|
|
||||||
|
### Task 1.3: Define Client-Facing Compatibility Targets
|
||||||
|
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md), [gitlab-duo-codex-parity-plan.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/gitlab-duo-codex-parity-plan.md)
|
||||||
|
- **Description**: Define exactly which external flows must work to call GitLab Duo support "like codex".
|
||||||
|
- **Dependencies**: Task 1.2
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- Required surfaces are listed:
|
||||||
|
- Claude-compatible route
|
||||||
|
- OpenAI `chat/completions`
|
||||||
|
- OpenAI `responses`
|
||||||
|
- optional downstream websocket path
|
||||||
|
- Non-goals are explicit if GitLab upstream cannot support them.
|
||||||
|
- **Validation**:
|
||||||
|
- Maintainer review of stated scope.
|
||||||
|
|
||||||
|
## Sprint 2: Primary Transport Parity
|
||||||
|
**Goal**: Move GitLab Duo execution onto a transport that supports codex-like runtime behavior.
|
||||||
|
|
||||||
|
**Demo/Validation**:
|
||||||
|
- A GitLab Duo model works over real streaming through `/v1/chat/completions`.
|
||||||
|
- No synthetic "collect full body then fake stream" path remains on the primary flow.
|
||||||
|
|
||||||
|
### Task 2.1: Refactor GitLab Executor Into Strategy Layers
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||||
|
- **Description**: Split current executor into explicit strategies:
|
||||||
|
- auth refresh/direct access refresh
|
||||||
|
- gateway transport
|
||||||
|
- GitLab REST fallback transport
|
||||||
|
- downstream translation helpers
|
||||||
|
- **Dependencies**: Sprint 1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- Executor no longer mixes discovery, refresh, fallback selection, and response synthesis in one path.
|
||||||
|
- Transport choice is testable in isolation.
|
||||||
|
- **Validation**:
|
||||||
|
- Unit tests for strategy selection and fallback boundaries.
|
||||||
|
|
||||||
|
### Task 2.2: Implement Real Streaming Path
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go)
|
||||||
|
- **Description**: Replace synthetic streaming with true upstream incremental forwarding:
|
||||||
|
- use gateway stream if available
|
||||||
|
- otherwise consume GitLab Code Suggestions streaming response and map chunks incrementally
|
||||||
|
- **Dependencies**: Task 2.1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- `ExecuteStream` emits chunks before upstream completion.
|
||||||
|
- error handling preserves status and early failure semantics.
|
||||||
|
- **Validation**:
|
||||||
|
- tests with chunked upstream server
|
||||||
|
- manual curl check against `/v1/chat/completions` with `stream=true`
|
||||||
|
|
||||||
|
### Task 2.3: Preserve Upstream Auth And Headers Correctly
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go)
|
||||||
|
- **Description**: Use `direct_access` connection details as first-class transport state:
|
||||||
|
- gateway token
|
||||||
|
- expiry
|
||||||
|
- mandatory forwarded headers
|
||||||
|
- model metadata
|
||||||
|
- **Dependencies**: Task 2.1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- executor stops ignoring gateway headers/token when transport requires them
|
||||||
|
- refresh logic never over-fetches `direct_access`
|
||||||
|
- **Validation**:
|
||||||
|
- tests verifying propagated headers and refresh interval behavior
|
||||||
|
|
||||||
|
## Sprint 3: Request/Response Semantics Parity
|
||||||
|
**Goal**: Make GitLab Duo behave correctly under the same request shapes that current `codex` consumers send.
|
||||||
|
|
||||||
|
**Demo/Validation**:
|
||||||
|
- OpenAI and Claude-compatible clients can do non-streaming and streaming conversations without losing structure.
|
||||||
|
|
||||||
|
### Task 3.1: Normalize Multi-Turn Message Mapping
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/translator](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/translator)
|
||||||
|
- **Description**: Replace the current "flatten prompt into one instruction" behavior with stable multi-turn mapping:
|
||||||
|
- preserve system context
|
||||||
|
- preserve user/assistant ordering
|
||||||
|
- maintain bounded context truncation
|
||||||
|
- **Dependencies**: Sprint 2
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- multi-turn requests are not collapsed into a lossy single string unless fallback mode explicitly requires it
|
||||||
|
- truncation policy is deterministic and tested
|
||||||
|
- **Validation**:
|
||||||
|
- golden tests for request mapping
|
||||||
|
|
||||||
|
### Task 3.2: Tool Calling Compatibility Layer
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go)
|
||||||
|
- **Description**: Decide and implement one of two paths:
|
||||||
|
- native pass-through if GitLab gateway supports tool/function structures
|
||||||
|
- strict downgrade path with explicit unsupported errors instead of silent field loss
|
||||||
|
- **Dependencies**: Task 3.1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- tool-related fields are either preserved correctly or rejected explicitly
|
||||||
|
- no silent corruption of tool names, tool calls, or tool results
|
||||||
|
- **Validation**:
|
||||||
|
- table-driven tests for tool payloads
|
||||||
|
- one manual client scenario using tools
|
||||||
|
|
||||||
|
### Task 3.3: Token Counting And Usage Reporting Fidelity
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/usage_helpers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/usage_helpers.go)
|
||||||
|
- **Description**: Improve token/usage reporting so GitLab models behave like first-class providers in logs and scheduling.
|
||||||
|
- **Dependencies**: Sprint 2
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- `CountTokens` uses the closest supported estimation path
|
||||||
|
- usage logging distinguishes prompt vs completion when possible
|
||||||
|
- **Validation**:
|
||||||
|
- unit tests for token estimation outputs
|
||||||
|
|
||||||
|
## Sprint 4: Responses And Session Parity
|
||||||
|
**Goal**: Reach codex-level support for OpenAI Responses clients and long-lived sessions where GitLab upstream permits it.
|
||||||
|
|
||||||
|
**Demo/Validation**:
|
||||||
|
- `/v1/responses` works with GitLab Duo in a realistic client flow.
|
||||||
|
- If websocket parity is not possible, the code explicitly declines it and keeps HTTP paths stable.
|
||||||
|
|
||||||
|
### Task 4.1: Make GitLab Compatible With `/v1/responses`
|
||||||
|
- **Location**: [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||||
|
- **Description**: Ensure GitLab transport can safely back the Responses API path, including compact responses if applicable.
|
||||||
|
- **Dependencies**: Sprint 3
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- GitLab Duo can be selected behind `/v1/responses`
|
||||||
|
- response IDs and follow-up semantics are defined
|
||||||
|
- **Validation**:
|
||||||
|
- handler tests analogous to codex/openai responses tests
|
||||||
|
|
||||||
|
### Task 4.2: Evaluate Downstream Websocket Parity
|
||||||
|
- **Location**: [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
|
||||||
|
- **Description**: Decide whether GitLab Duo can support downstream websocket sessions like codex:
|
||||||
|
- if yes, add session-aware execution path
|
||||||
|
- if no, mark GitLab auth as websocket-ineligible and keep HTTP routes first-class
|
||||||
|
- **Dependencies**: Task 4.1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- websocket behavior is explicit, not accidental
|
||||||
|
- no route claims websocket support when the upstream cannot honor it
|
||||||
|
- **Validation**:
|
||||||
|
- websocket handler tests or explicit capability tests
|
||||||
|
|
||||||
|
### Task 4.3: Add Session Cleanup And Failure Recovery Semantics
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/cliproxy/auth/conductor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/auth/conductor.go)
|
||||||
|
- **Description**: Add codex-like session cleanup, retry boundaries, and model suspension/resume behavior for GitLab failures and quota events.
|
||||||
|
- **Dependencies**: Sprint 2
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- auth/model cooldown behavior is predictable on GitLab 4xx/5xx/quota responses
|
||||||
|
- executor cleans up per-session resources if any are introduced
|
||||||
|
- **Validation**:
|
||||||
|
- tests for quota and retry behavior
|
||||||
|
|
||||||
|
## Sprint 5: Client UX, Model UX, And Manual E2E
|
||||||
|
**Goal**: Make GitLab Duo feel like a normal built-in provider to operators and downstream clients.
|
||||||
|
|
||||||
|
**Demo/Validation**:
|
||||||
|
- A documented setup exists for "login once, point Claude Code at CLIProxyAPI, use GitLab Duo-backed model".
|
||||||
|
|
||||||
|
### Task 5.1: Model Alias And Provider UX Cleanup
|
||||||
|
- **Location**: [sdk/cliproxy/service.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/service.go), [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
|
||||||
|
- **Description**: Normalize what users see:
|
||||||
|
- stable alias such as `gitlab-duo`
|
||||||
|
- discovered upstream model names
|
||||||
|
- optional prefix behavior
|
||||||
|
- account labels that clearly distinguish OAuth vs PAT
|
||||||
|
- **Dependencies**: Sprint 3
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- users can select a stable GitLab alias even when upstream model changes
|
||||||
|
- dynamic model discovery does not cause confusing model churn
|
||||||
|
- **Validation**:
|
||||||
|
- registry tests and manual `/v1/models` inspection
|
||||||
|
|
||||||
|
### Task 5.2: Add Real End-To-End Acceptance Tests
|
||||||
|
- **Location**: [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go), [sdk/api/handlers/openai](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai)
|
||||||
|
- **Description**: Add higher-level tests covering the actual proxy surfaces:
|
||||||
|
- OpenAI `chat/completions`
|
||||||
|
- OpenAI `responses`
|
||||||
|
- Claude-compatible request path if GitLab is routed there
|
||||||
|
- **Dependencies**: Sprint 4
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- tests fail if streaming regresses into synthetic buffering again
|
||||||
|
- tests cover at least one tool-related request and one multi-turn request
|
||||||
|
- **Validation**:
|
||||||
|
- `go test ./...`
|
||||||
|
|
||||||
|
### Task 5.3: Publish Operator Documentation
|
||||||
|
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
|
||||||
|
- **Description**: Document:
|
||||||
|
- OAuth setup requirements
|
||||||
|
- PAT requirements
|
||||||
|
- current capability matrix
|
||||||
|
- known limitations if websocket/tool parity is partial
|
||||||
|
- **Dependencies**: Sprint 5.1
|
||||||
|
- **Acceptance Criteria**:
|
||||||
|
- setup instructions are enough for a new user to reproduce the GitLab Duo flow
|
||||||
|
- limitations are explicit
|
||||||
|
- **Validation**:
|
||||||
|
- dry-run docs review from a clean environment
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
- Keep `go test ./...` green after every committable task.
|
||||||
|
- Add table-driven tests first for request mapping, refresh behavior, and dynamic model registration.
|
||||||
|
- Add transport tests with `httptest.Server` for:
|
||||||
|
- real chunked streaming
|
||||||
|
- header propagation from `direct_access`
|
||||||
|
- upstream fallback rules
|
||||||
|
- Add at least one manual acceptance checklist:
|
||||||
|
- login via OAuth
|
||||||
|
- login via PAT
|
||||||
|
- list models
|
||||||
|
- run one streaming prompt via OpenAI route
|
||||||
|
- run one prompt from the target downstream client
|
||||||
|
|
||||||
|
## Potential Risks & Gotchas
|
||||||
|
- GitLab public docs expose `direct_access`, but do not fully document every possible AI gateway path. We should isolate any empirically discovered gateway assumptions behind one transport layer and feature flags.
|
||||||
|
- `chat/completions` availability differs by GitLab offering and version. The executor must not assume it always exists.
|
||||||
|
- Code Suggestions is completion-oriented; lossy mapping from rich chat/tool payloads will make GitLab Duo feel worse than codex unless explicitly handled.
|
||||||
|
- Synthetic streaming is not good enough for codex parity and will cause regressions in interactive clients.
|
||||||
|
- Dynamic model discovery can create unstable UX if the stable alias and discovered model IDs are not separated cleanly.
|
||||||
|
- PAT auth may validate successfully while still lacking effective Duo permissions. Error reporting must surface this explicitly.
|
||||||
|
|
||||||
|
## Rollback Plan
|
||||||
|
- Keep the current basic GitLab executor behind a fallback mode until the new transport path is stable.
|
||||||
|
- If parity work destabilizes existing providers, revert only GitLab-specific executor changes and leave auth support intact.
|
||||||
|
- Preserve the stable `gitlab-duo` alias so rollback does not break client configuration.
|
||||||
2
go.mod
2
go.mod
@@ -91,8 +91,8 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,13 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/oauth2/google"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultAPICallTimeout = 60 * time.Second
|
const defaultAPICallTimeout = 60 * time.Second
|
||||||
@@ -725,47 +723,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||||
proxyStr = strings.TrimSpace(proxyStr)
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||||
if proxyStr == "" {
|
if errBuild != nil {
|
||||||
|
log.WithError(errBuild).Debug("build proxy transport failed")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return transport
|
||||||
proxyURL, errParse := url.Parse(proxyStr)
|
|
||||||
if errParse != nil {
|
|
||||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
|
||||||
log.Debug("proxy URL missing scheme/host")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxyURL.Scheme == "socks5" {
|
|
||||||
var proxyAuth *proxy.Auth
|
|
||||||
if proxyURL.User != nil {
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &http.Transport{
|
|
||||||
Proxy: nil,
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||||
|
|||||||
@@ -2,172 +2,112 @@ package management
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type memoryAuthStore struct {
|
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
mu sync.Mutex
|
t.Parallel()
|
||||||
items map[string]*coreauth.Auth
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
h := &Handler{
|
||||||
_ = ctx
|
cfg: &config.Config{
|
||||||
s.mu.Lock()
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
defer s.mu.Unlock()
|
|
||||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
|
||||||
for _, a := range s.items {
|
|
||||||
out = append(out, a.Clone())
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
|
||||||
_ = ctx
|
|
||||||
if auth == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
if s.items == nil {
|
|
||||||
s.items = make(map[string]*coreauth.Auth)
|
|
||||||
}
|
|
||||||
s.items[auth.ID] = auth.Clone()
|
|
||||||
s.mu.Unlock()
|
|
||||||
return auth.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
|
||||||
_ = ctx
|
|
||||||
s.mu.Lock()
|
|
||||||
delete(s.items, id)
|
|
||||||
s.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
t.Fatalf("expected POST, got %s", r.Method)
|
|
||||||
}
|
|
||||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
|
||||||
t.Fatalf("unexpected content-type: %s", ct)
|
|
||||||
}
|
|
||||||
bodyBytes, _ := io.ReadAll(r.Body)
|
|
||||||
_ = r.Body.Close()
|
|
||||||
values, err := url.ParseQuery(string(bodyBytes))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse form: %v", err)
|
|
||||||
}
|
|
||||||
if values.Get("grant_type") != "refresh_token" {
|
|
||||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
|
||||||
}
|
|
||||||
if values.Get("refresh_token") != "rt" {
|
|
||||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
|
||||||
}
|
|
||||||
if values.Get("client_id") != antigravityOAuthClientID {
|
|
||||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
|
||||||
}
|
|
||||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
|
||||||
t.Fatalf("unexpected client_secret")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"access_token": "new-token",
|
|
||||||
"refresh_token": "rt2",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"token_type": "Bearer",
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
store := &memoryAuthStore{}
|
|
||||||
manager := coreauth.NewManager(store, nil, nil)
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-test.json",
|
|
||||||
FileName: "antigravity-test.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "old-token",
|
|
||||||
"refresh_token": "rt",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
|
||||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
|
||||||
t.Fatalf("register auth: %v", err)
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
if httpTransport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
geminiAuth := &coreauth.Auth{
|
||||||
|
ID: "gemini:apikey:123",
|
||||||
|
Provider: "gemini",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "shared-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
compatAuth := &coreauth.Auth{
|
||||||
|
ID: "openai-compatibility:bohe:456",
|
||||||
|
Provider: "bohe",
|
||||||
|
Label: "bohe",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "shared-key",
|
||||||
|
"compat_name": "bohe",
|
||||||
|
"provider_key": "bohe",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register gemini auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register compat auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiIndex := geminiAuth.EnsureIndex()
|
||||||
|
compatIndex := compatAuth.EnsureIndex()
|
||||||
|
if geminiIndex == compatIndex {
|
||||||
|
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &Handler{authManager: manager}
|
h := &Handler{authManager: manager}
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
|
||||||
if err != nil {
|
gotGemini := h.authByIndex(geminiIndex)
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
if gotGemini == nil {
|
||||||
|
t.Fatal("expected gemini auth by index")
|
||||||
}
|
}
|
||||||
if token != "new-token" {
|
if gotGemini.ID != geminiAuth.ID {
|
||||||
t.Fatalf("expected refreshed token, got %q", token)
|
t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
|
||||||
}
|
|
||||||
if callCount != 1 {
|
|
||||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, ok := manager.GetByID(auth.ID)
|
gotCompat := h.authByIndex(compatIndex)
|
||||||
if !ok || updated == nil {
|
if gotCompat == nil {
|
||||||
t.Fatalf("expected auth in manager after update")
|
t.Fatal("expected compat auth by index")
|
||||||
}
|
}
|
||||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
if gotCompat.ID != compatAuth.ID {
|
||||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-valid.json",
|
|
||||||
FileName: "antigravity-valid.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "ok-token",
|
|
||||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
h := &Handler{}
|
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
|
||||||
}
|
|
||||||
if token != "ok-token" {
|
|
||||||
t.Fatalf("expected existing token, got %q", token)
|
|
||||||
}
|
|
||||||
if callCount != 0 {
|
|
||||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
|
||||||
|
files := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
|
||||||
|
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
for _, file := range files {
|
||||||
|
part, err := writer.CreateFormFile("file", file.name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create multipart file: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||||
|
t.Fatalf("failed to write multipart content: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close multipart writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.UploadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
|
||||||
|
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
fullPath := filepath.Join(authDir, file.name)
|
||||||
|
data, err := os.ReadFile(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
|
||||||
|
}
|
||||||
|
if string(data) != file.content {
|
||||||
|
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auths := manager.List()
|
||||||
|
if len(auths) != len(files) {
|
||||||
|
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
|
||||||
|
existingName := "alpha.json"
|
||||||
|
existingContent := `{"type":"codex","email":"alpha@example.com"}`
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to seed existing auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
files := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{name: existingName, content: `{"type":"codex"`},
|
||||||
|
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
for _, file := range files {
|
||||||
|
part, err := writer.CreateFormFile("file", file.name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create multipart file: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||||
|
t.Fatalf("failed to write multipart content: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close multipart writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.UploadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusMultiStatus {
|
||||||
|
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(authDir, existingName))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected existing auth file to remain readable: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != existingContent {
|
||||||
|
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected valid auth file to be created: %v", err)
|
||||||
|
}
|
||||||
|
if string(betaData) != files[1].content {
|
||||||
|
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
files := []string{"alpha.json", "beta.json"}
|
||||||
|
for _, name := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(
|
||||||
|
http.MethodDelete,
|
||||||
|
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.DeleteAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
|
||||||
|
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range files {
|
||||||
|
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
fileName := "download-user.json"
|
||||||
|
expected := []byte(`{"type":"codex"}`)
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := rec.Body.Bytes(); string(got) != string(expected) {
|
||||||
|
t.Fatalf("unexpected download content: %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
|
||||||
|
|
||||||
|
for _, name := range []string{
|
||||||
|
"../external/secret.json",
|
||||||
|
`..\\external\\secret.json`,
|
||||||
|
"nested/secret.json",
|
||||||
|
`nested\\secret.json`,
|
||||||
|
} {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tempDir, "auth")
|
||||||
|
externalDir := filepath.Join(tempDir, "external")
|
||||||
|
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(externalDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("failed to create external dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secretName := "secret.json"
|
||||||
|
secretPath := filepath.Join(externalDir, secretName)
|
||||||
|
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write external file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(
|
||||||
|
http.MethodGet,
|
||||||
|
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
|
||||||
|
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v4/user":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": 42,
|
||||||
|
"username": "gitlab-user",
|
||||||
|
"name": "GitLab User",
|
||||||
|
"email": "gitlab@example.com",
|
||||||
|
})
|
||||||
|
case "/api/v4/personal_access_tokens/self":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": 7,
|
||||||
|
"name": "management-center",
|
||||||
|
"scopes": []string{"api", "read_user"},
|
||||||
|
"user_id": 42,
|
||||||
|
})
|
||||||
|
case "/api/v4/code_suggestions/direct_access":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"base_url": "https://cloud.gitlab.example.com",
|
||||||
|
"token": "gateway-token",
|
||||||
|
"expires_at": 1893456000,
|
||||||
|
"headers": map[string]string{
|
||||||
|
"X-Gitlab-Realm": "saas",
|
||||||
|
},
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
|
||||||
|
h.tokenStore = store
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.RequestGitLabPATToken(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got := resp["status"]; got != "ok" {
|
||||||
|
t.Fatalf("status = %#v, want ok", got)
|
||||||
|
}
|
||||||
|
if got := resp["model_provider"]; got != "anthropic" {
|
||||||
|
t.Fatalf("model_provider = %#v, want anthropic", got)
|
||||||
|
}
|
||||||
|
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
store.mu.Lock()
|
||||||
|
defer store.mu.Unlock()
|
||||||
|
if len(store.items) != 1 {
|
||||||
|
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
|
||||||
|
}
|
||||||
|
var saved *coreauth.Auth
|
||||||
|
for _, item := range store.items {
|
||||||
|
saved = item
|
||||||
|
}
|
||||||
|
if saved == nil {
|
||||||
|
t.Fatal("expected saved auth record")
|
||||||
|
}
|
||||||
|
if saved.Provider != "gitlab" {
|
||||||
|
t.Fatalf("provider = %q, want gitlab", saved.Provider)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||||
|
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["model_provider"]; got != "anthropic" {
|
||||||
|
t.Fatalf("saved model_provider = %#v, want anthropic", got)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
|
||||||
|
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
state := "gitlab-state-123"
|
||||||
|
RegisterOAuthSession(state, "gitlab")
|
||||||
|
t.Cleanup(func() { CompleteOAuthSession(state) })
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.PostOAuthCallback(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
|
||||||
|
data, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read callback file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]string
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
t.Fatalf("decode callback payload: %v", err)
|
||||||
|
}
|
||||||
|
if got := payload["code"]; got != "test-code" {
|
||||||
|
t.Fatalf("callback code = %q, want test-code", got)
|
||||||
|
}
|
||||||
|
if got := payload["state"]; got != state {
|
||||||
|
t.Fatalf("callback state = %q, want %q", got, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
|
||||||
|
provider, err := NormalizeOAuthProvider("gitlab")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
|
||||||
|
}
|
||||||
|
if provider != "gitlab" {
|
||||||
|
t.Fatalf("provider = %q, want gitlab", provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -509,8 +509,12 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
for i := range arr {
|
for i := range arr {
|
||||||
normalizeVertexCompatKey(&arr[i])
|
normalizeVertexCompatKey(&arr[i])
|
||||||
|
if arr[i].APIKey == "" {
|
||||||
|
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
h.cfg.VertexCompatAPIKey = arr
|
h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -228,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
|||||||
return "anthropic", nil
|
return "anthropic", nil
|
||||||
case "codex", "openai":
|
case "codex", "openai":
|
||||||
return "codex", nil
|
return "codex", nil
|
||||||
|
case "gitlab":
|
||||||
|
return "gitlab", nil
|
||||||
case "gemini", "google":
|
case "gemini", "google":
|
||||||
return "gemini", nil
|
return "gemini", nil
|
||||||
case "iflow", "i-flow":
|
case "iflow", "i-flow":
|
||||||
|
|||||||
49
internal/api/handlers/management/test_store_test.go
Normal file
49
internal/api/handlers/management/test_store_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryAuthStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
items map[string]*coreauth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||||
|
for _, item := range s.items {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.items == nil {
|
||||||
|
s.items = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
s.items[auth.ID] = auth
|
||||||
|
return auth.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.items, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) SetBaseDir(string) {}
|
||||||
@@ -403,6 +403,20 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if errStr == "" {
|
||||||
|
errStr = c.Query("error_description")
|
||||||
|
}
|
||||||
|
if state != "" {
|
||||||
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
|
})
|
||||||
|
|
||||||
s.engine.GET("/google/callback", func(c *gin.Context) {
|
s.engine.GET("/google/callback", func(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
@@ -658,6 +672,8 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
|
|
||||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||||
|
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
|
||||||
|
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
|
||||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||||
@@ -666,6 +682,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||||
|
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
|
||||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ package claude
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
tls "github.com/refraction-networking/utls"
|
tls "github.com/refraction-networking/utls"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
|
|||||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||||
var dialer proxy.Dialer = proxy.Direct
|
var dialer proxy.Dialer = proxy.Direct
|
||||||
if cfg != nil && cfg.ProxyURL != "" {
|
if cfg != nil {
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||||
if err != nil {
|
if errBuild != nil {
|
||||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||||
} else {
|
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
dialer = proxyDialer
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
|
||||||
} else {
|
|
||||||
dialer = pDialer
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BaseURL = "https://copilot.tencent.com"
|
||||||
|
DefaultDomain = "www.codebuddy.cn"
|
||||||
|
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
|
||||||
|
|
||||||
|
codeBuddyStatePath = "/v2/plugin/auth/state"
|
||||||
|
codeBuddyTokenPath = "/v2/plugin/auth/token"
|
||||||
|
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
|
||||||
|
pollInterval = 5 * time.Second
|
||||||
|
maxPollDuration = 5 * time.Minute
|
||||||
|
codeLoginPending = 11217
|
||||||
|
codeSuccess = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodeBuddyAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
|
||||||
|
}
|
||||||
|
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthState holds the state and auth URL returned by the auth state API.
|
||||||
|
type AuthState struct {
|
||||||
|
State string
|
||||||
|
AuthURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
|
||||||
|
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
|
||||||
|
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
|
||||||
|
body := []byte("{}")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-Domain", "copilot.tencent.com")
|
||||||
|
req.Header.Set("X-No-Authorization", "true")
|
||||||
|
req.Header.Set("X-No-User-Id", "true")
|
||||||
|
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||||
|
req.Header.Set("X-No-Department-Info", "true")
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
req.Header.Set("X-Request-ID", requestID)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy auth state: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
State string `json:"state"`
|
||||||
|
AuthURL string `json:"authUrl"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Code != codeSuccess {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthState{
|
||||||
|
State: result.Data.State,
|
||||||
|
AuthURL: result.Data.AuthURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type pollResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
RequestID string `json:"requestId"`
|
||||||
|
Data *struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// doPollRequest performs a single polling request, safely reading and closing the response body
|
||||||
|
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
|
||||||
|
}
|
||||||
|
a.applyPollHeaders(req)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy poll: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
return body, resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForToken polls until the user completes browser authorization and returns auth data.
|
||||||
|
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
|
||||||
|
deadline := time.Now().Add(maxPollDuration)
|
||||||
|
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(pollInterval):
|
||||||
|
}
|
||||||
|
|
||||||
|
body, statusCode, err := a.doPollRequest(ctx, pollURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("codebuddy poll: request error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCode != http.StatusOK {
|
||||||
|
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var result pollResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch result.Code {
|
||||||
|
case codeSuccess:
|
||||||
|
if result.Data == nil {
|
||||||
|
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
|
||||||
|
}
|
||||||
|
userID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||||
|
return &CodeBuddyTokenStorage{
|
||||||
|
AccessToken: result.Data.AccessToken,
|
||||||
|
RefreshToken: result.Data.RefreshToken,
|
||||||
|
ExpiresIn: result.Data.ExpiresIn,
|
||||||
|
TokenType: result.Data.TokenType,
|
||||||
|
Domain: result.Data.Domain,
|
||||||
|
UserID: userID,
|
||||||
|
Type: "codebuddy",
|
||||||
|
}, nil
|
||||||
|
case codeLoginPending:
|
||||||
|
// continue polling
|
||||||
|
default:
|
||||||
|
// TODO: when the CodeBuddy API error code for user denial is known,
|
||||||
|
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
|
||||||
|
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrPollingTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
|
||||||
|
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", ErrJWTDecodeFailed
|
||||||
|
}
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||||
|
}
|
||||||
|
if claims.Sub == "" {
|
||||||
|
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
|
||||||
|
}
|
||||||
|
return claims.Sub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken exchanges a refresh token for a new access token.
|
||||||
|
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
|
||||||
|
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
|
||||||
|
if domain == "" {
|
||||||
|
domain = DefaultDomain
|
||||||
|
}
|
||||||
|
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
|
||||||
|
body := []byte("{}")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-Domain", domain)
|
||||||
|
req.Header.Set("X-Refresh-Token", refreshToken)
|
||||||
|
req.Header.Set("X-Auth-Refresh-Source", "plugin")
|
||||||
|
req.Header.Set("X-Request-ID", requestID)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("X-User-Id", userID)
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy refresh: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn"`
|
||||||
|
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Code != codeSuccess {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
if result.Data == nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
|
||||||
|
}
|
||||||
|
|
||||||
|
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||||
|
if newUserID == "" {
|
||||||
|
newUserID = userID
|
||||||
|
}
|
||||||
|
tokenDomain := result.Data.Domain
|
||||||
|
if tokenDomain == "" {
|
||||||
|
tokenDomain = domain
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CodeBuddyTokenStorage{
|
||||||
|
AccessToken: result.Data.AccessToken,
|
||||||
|
RefreshToken: result.Data.RefreshToken,
|
||||||
|
ExpiresIn: result.Data.ExpiresIn,
|
||||||
|
RefreshExpiresIn: result.Data.RefreshExpiresIn,
|
||||||
|
TokenType: result.Data.TokenType,
|
||||||
|
Domain: tokenDomain,
|
||||||
|
UserID: newUserID,
|
||||||
|
Type: "codebuddy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-No-Authorization", "true")
|
||||||
|
req.Header.Set("X-No-User-Id", "true")
|
||||||
|
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||||
|
req.Header.Set("X-No-Department-Info", "true")
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
}
|
||||||
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
|
||||||
|
func newTestAuth(serverURL string) *CodeBuddyAuth {
|
||||||
|
return &CodeBuddyAuth{
|
||||||
|
httpClient: http.DefaultClient,
|
||||||
|
baseURL: serverURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeJWT builds a minimal JWT with the given sub claim for testing.
|
||||||
|
func fakeJWT(sub string) string {
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
|
||||||
|
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
|
||||||
|
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||||
|
return header + "." + encodedPayload + ".sig"
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- FetchAuthState tests ---
|
||||||
|
|
||||||
|
func TestFetchAuthState_Success(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != codeBuddyStatePath {
|
||||||
|
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
|
||||||
|
}
|
||||||
|
if got := r.URL.Query().Get("platform"); got != "CLI" {
|
||||||
|
t.Errorf("expected platform=CLI, got %s", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("User-Agent"); got != UserAgent {
|
||||||
|
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"state": "test-state-abc",
|
||||||
|
"authUrl": "https://example.com/login?state=test-state-abc",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
result, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.State != "test-state-abc" {
|
||||||
|
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
|
||||||
|
}
|
||||||
|
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
|
||||||
|
t.Errorf("unexpected authURL: %s", result.AuthURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_NonOKStatus(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("internal error"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-200 status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_APIErrorCode(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 10001,
|
||||||
|
"msg": "rate limited",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-zero code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_MissingData(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"state": "",
|
||||||
|
"authUrl": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty state/authUrl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- RefreshToken tests ---
|
||||||
|
|
||||||
|
func TestRefreshToken_Success(t *testing.T) {
|
||||||
|
newAccessToken := fakeJWT("refreshed-user-456")
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != codeBuddyRefreshPath {
|
||||||
|
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
|
||||||
|
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
|
||||||
|
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-User-Id"); got != "user-123" {
|
||||||
|
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
|
||||||
|
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": newAccessToken,
|
||||||
|
"refreshToken": "new-refresh-token",
|
||||||
|
"expiresIn": 3600,
|
||||||
|
"refreshExpiresIn": 86400,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": "custom.domain.com",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if storage.AccessToken != newAccessToken {
|
||||||
|
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
|
||||||
|
}
|
||||||
|
if storage.RefreshToken != "new-refresh-token" {
|
||||||
|
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
|
||||||
|
}
|
||||||
|
if storage.UserID != "refreshed-user-456" {
|
||||||
|
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
|
||||||
|
}
|
||||||
|
if storage.ExpiresIn != 3600 {
|
||||||
|
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
|
||||||
|
}
|
||||||
|
if storage.RefreshExpiresIn != 86400 {
|
||||||
|
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
|
||||||
|
}
|
||||||
|
if storage.Domain != "custom.domain.com" {
|
||||||
|
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
|
||||||
|
}
|
||||||
|
if storage.Type != "codebuddy" {
|
||||||
|
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_DefaultDomain(t *testing.T) {
|
||||||
|
var receivedDomain string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedDomain = r.Header.Get("X-Domain")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": fakeJWT("user-1"),
|
||||||
|
"refreshToken": "rt",
|
||||||
|
"expiresIn": 3600,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": DefaultDomain,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if receivedDomain != DefaultDomain {
|
||||||
|
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Unauthorized(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401 response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Forbidden(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 403 response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_APIErrorCode(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 40001,
|
||||||
|
"msg": "invalid refresh token",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-zero API code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
|
||||||
|
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
|
||||||
|
// When the response domain is empty, it should fall back to the request domain.
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": "not-a-valid-jwt",
|
||||||
|
"refreshToken": "new-rt",
|
||||||
|
"expiresIn": 7200,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if storage.UserID != "original-uid" {
|
||||||
|
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
|
||||||
|
}
|
||||||
|
if storage.Domain != "original.domain.com" {
|
||||||
|
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
22
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
22
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package codebuddy_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDecodeUserID_ValidJWT(t *testing.T) {
|
||||||
|
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
|
||||||
|
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
|
||||||
|
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
|
||||||
|
auth := codebuddy.NewCodeBuddyAuth(nil)
|
||||||
|
userID, err := auth.DecodeUserID(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if userID != "test-user-id-123" {
|
||||||
|
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
25
internal/auth/codebuddy/errors.go
Normal file
25
internal/auth/codebuddy/errors.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
|
||||||
|
ErrAccessDenied = errors.New("codebuddy: access denied by user")
|
||||||
|
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
|
||||||
|
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetUserFriendlyMessage(err error) string {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrPollingTimeout):
|
||||||
|
return "Authentication timed out. Please try again."
|
||||||
|
case errors.Is(err, ErrAccessDenied):
|
||||||
|
return "Access denied. Please try again and approve the login request."
|
||||||
|
case errors.Is(err, ErrJWTDecodeFailed):
|
||||||
|
return "Failed to decode token. Please try logging in again."
|
||||||
|
case errors.Is(err, ErrTokenFetchFailed):
|
||||||
|
return "Failed to fetch token from server. Please try again."
|
||||||
|
default:
|
||||||
|
return "Authentication failed: " + err.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
65
internal/auth/codebuddy/token.go
Normal file
65
internal/auth/codebuddy/token.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// Package codebuddy provides authentication and token management functionality
|
||||||
|
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
|
||||||
|
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
|
||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
|
||||||
|
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
|
||||||
|
// for managing access tokens and user account information.
|
||||||
|
type CodeBuddyTokenStorage struct {
|
||||||
|
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// ExpiresIn is the number of seconds until the access token expires.
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
// RefreshExpiresIn is the number of seconds until the refresh token expires.
|
||||||
|
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
|
||||||
|
// TokenType is the type of token, typically "bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// Domain is the CodeBuddy service domain/region.
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
// UserID is the user ID associated with this token.
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
// Type indicates the authentication provider type, always "codebuddy" for this storage.
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
|
||||||
|
// This method creates the necessary directory structure and writes the token
|
||||||
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - authFilePath: The full path where the token file should be saved
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the operation fails, nil otherwise
|
||||||
|
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
misc.LogSavingCredentials(authFilePath)
|
||||||
|
s.Type = "codebuddy"
|
||||||
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(s); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
33
internal/auth/cursor/filename.go
Normal file
33
internal/auth/cursor/filename.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package cursor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Cursor credentials.
|
||||||
|
// Priority: explicit label > auto-generated from JWT sub hash.
|
||||||
|
// If both label and subHash are empty, falls back to "cursor.json".
|
||||||
|
func CredentialFileName(label, subHash string) string {
|
||||||
|
label = strings.TrimSpace(label)
|
||||||
|
subHash = strings.TrimSpace(subHash)
|
||||||
|
if label != "" {
|
||||||
|
return fmt.Sprintf("cursor.%s.json", label)
|
||||||
|
}
|
||||||
|
if subHash != "" {
|
||||||
|
return fmt.Sprintf("cursor.%s.json", subHash)
|
||||||
|
}
|
||||||
|
return "cursor.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayLabel returns a human-readable label for the Cursor account.
|
||||||
|
func DisplayLabel(label, subHash string) string {
|
||||||
|
label = strings.TrimSpace(label)
|
||||||
|
if label != "" {
|
||||||
|
return "Cursor " + label
|
||||||
|
}
|
||||||
|
if subHash != "" {
|
||||||
|
return "Cursor " + subHash
|
||||||
|
}
|
||||||
|
return "Cursor User"
|
||||||
|
}
|
||||||
249
internal/auth/cursor/oauth.go
Normal file
249
internal/auth/cursor/oauth.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
|
||||||
|
package cursor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// CursorLoginURL is the browser page that starts the PKCE login flow.
	CursorLoginURL = "https://cursor.com/loginDeepControl"
	// CursorPollURL is polled until the user completes login in the browser.
	CursorPollURL = "https://api2.cursor.sh/auth/poll"
	// CursorRefreshURL exchanges a refresh token for a fresh access token.
	CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"

	// Polling backoff: delay grows from pollBaseDelay by pollBackoffMultiply
	// each round, capped at pollMaxDelay, for at most pollMaxAttempts rounds.
	pollMaxAttempts     = 150
	pollBaseDelay       = 1 * time.Second
	pollMaxDelay        = 10 * time.Second
	pollBackoffMultiply = 1.2
	// maxConsecutiveErrors aborts polling after this many back-to-back
	// transport failures (a 404 "still waiting" resets the streak).
	maxConsecutiveErrors = 10
)
|
||||||
|
|
||||||
|
// AuthParams holds the PKCE parameters for Cursor login.
type AuthParams struct {
	Verifier  string // PKCE code verifier, presented later when polling
	Challenge string // base64url(SHA-256(Verifier))
	UUID      string // random identifier correlating the login page and the poll requests
	LoginURL  string // URL the user must open in a browser to authorize
}
|
||||||
|
|
||||||
|
// TokenPair holds the access and refresh tokens from Cursor.
type TokenPair struct {
	AccessToken  string `json:"accessToken"`  // JWT used to authenticate API calls
	RefreshToken string `json:"refreshToken"` // may be omitted by refresh responses; callers keep the old one
}
|
||||||
|
|
||||||
|
// GeneratePKCE creates a PKCE verifier and challenge pair.
// The verifier is 96 random bytes encoded as unpadded base64url; the
// challenge is the base64url-encoded SHA-256 digest of the verifier string.
func GeneratePKCE() (verifier, challenge string, err error) {
	raw := make([]byte, 96)
	if _, err = rand.Read(raw); err != nil {
		return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
	}

	verifier = base64.RawURLEncoding.EncodeToString(raw)
	digest := sha256.Sum256([]byte(verifier))
	challenge = base64.RawURLEncoding.EncodeToString(digest[:])
	return verifier, challenge, nil
}
|
||||||
|
|
||||||
|
// GenerateAuthParams creates the full set of auth params for Cursor login.
|
||||||
|
func GenerateAuthParams() (*AuthParams, error) {
|
||||||
|
verifier, challenge, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
uuidBytes := make([]byte, 16)
|
||||||
|
if _, err = rand.Read(uuidBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
|
||||||
|
}
|
||||||
|
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||||
|
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
|
||||||
|
|
||||||
|
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
|
||||||
|
CursorLoginURL, challenge, uuid)
|
||||||
|
|
||||||
|
return &AuthParams{
|
||||||
|
Verifier: verifier,
|
||||||
|
Challenge: challenge,
|
||||||
|
UUID: uuid,
|
||||||
|
LoginURL: loginURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForAuth polls the Cursor auth endpoint until the user completes login.
|
||||||
|
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
|
||||||
|
delay := pollBaseDelay
|
||||||
|
consecutiveErrors := 0
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
|
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
consecutiveErrors++
|
||||||
|
if consecutiveErrors >= maxConsecutiveErrors {
|
||||||
|
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
|
||||||
|
}
|
||||||
|
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
// Still waiting for user to authorize
|
||||||
|
consecutiveErrors = 0
|
||||||
|
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
var tokens TokenPair
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
|
||||||
|
}
|
||||||
|
return &tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
|
||||||
|
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes a Cursor access token using the refresh token.
|
||||||
|
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
|
||||||
|
strings.NewReader("{}"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+refreshToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens TokenPair
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep original refresh token if not returned
|
||||||
|
if tokens.RefreshToken == "" {
|
||||||
|
tokens.RefreshToken = refreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
|
||||||
|
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
|
||||||
|
// the account. Returns empty string if parsing fails.
|
||||||
|
func ParseJWTSub(token string) string {
|
||||||
|
decoded := decodeJWTPayload(token)
|
||||||
|
if decoded == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return claims.Sub
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
// e.g. "auth0|user_2x..." → "a3f8b2c1"
func SubToShortHash(sub string) string {
	if sub == "" {
		return ""
	}
	sum := sha256.Sum256([]byte(sub))
	// First 4 digest bytes → 8 hex characters.
	return fmt.Sprintf("%x", sum[:4])
}
|
||||||
|
|
||||||
|
// decodeJWTPayload decodes the payload (middle) part of a JWT.
// It returns nil when the token is not a three-part JWT or the payload is not
// valid base64url.
func decodeJWTPayload(token string) []byte {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil
	}
	// JWT segments are unpadded base64url (RFC 7515). Strip any padding a
	// non-conforming issuer may have added and decode with RawURLEncoding
	// directly, instead of hand-converting to padded standard base64.
	decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return nil
	}
	return decoded
}
|
||||||
|
|
||||||
|
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
|
||||||
|
// Falls back to 1 hour from now if the token can't be parsed.
|
||||||
|
func GetTokenExpiry(token string) time.Time {
|
||||||
|
decoded := decodeJWTPayload(token)
|
||||||
|
if decoded == nil {
|
||||||
|
return time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Exp float64 `json:"exp"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
|
||||||
|
return time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
sec, frac := math.Modf(claims.Exp)
|
||||||
|
expiry := time.Unix(int64(sec), int64(frac*1e9))
|
||||||
|
// Subtract 5-minute safety margin
|
||||||
|
return expiry.Add(-5 * time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
// minDuration returns the smaller of two durations.
func minDuration(a, b time.Duration) time.Duration {
	if b < a {
		return b
	}
	return a
}
|
||||||
84
internal/auth/cursor/proto/connect.go
Normal file
84
internal/auth/cursor/proto/connect.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
	ConnectEndStreamFlag byte = 0x02
	// ConnectCompressionFlag indicates the payload is compressed (not supported).
	ConnectCompressionFlag byte = 0x01
	// ConnectFrameHeaderSize is the fixed 5-byte frame header.
	ConnectFrameHeaderSize = 5
)

// FrameConnectMessage wraps a protobuf payload in a Connect frame.
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
	out := make([]byte, ConnectFrameHeaderSize, ConnectFrameHeaderSize+len(data))
	out[0] = flags
	binary.BigEndian.PutUint32(out[1:], uint32(len(data)))
	return append(out, data...)
}

// ParseConnectFrame extracts one frame from a buffer.
// Returns (flags, payload, bytesConsumed, ok).
// ok is false when the buffer is too short for a complete frame.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
	if len(buf) < ConnectFrameHeaderSize {
		return 0, nil, 0, false
	}
	end := ConnectFrameHeaderSize + int(binary.BigEndian.Uint32(buf[1:ConnectFrameHeaderSize]))
	if len(buf) < end {
		return 0, nil, 0, false
	}
	return buf[0], buf[ConnectFrameHeaderSize:end], end, true
}
|
||||||
|
|
||||||
|
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
// The Code field contains the server-defined error code (e.g. gRPC standard codes
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
type ConnectError struct {
	Code    string // server-defined error code
	Message string // human-readable error description
}

// Error implements the error interface.
func (e *ConnectError) Error() string {
	return "Connect error " + e.Code + ": " + e.Message
}

// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
// Returns nil if there is no error in the trailer.
// On error, returns a *ConnectError with the server's error code and message.
func ParseConnectEndStream(data []byte) error {
	if len(data) == 0 {
		return nil
	}

	var trailer struct {
		Error *struct {
			Code    string `json:"code"`
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(data, &trailer); err != nil {
		return fmt.Errorf("failed to parse Connect end stream: %w", err)
	}
	if trailer.Error == nil {
		return nil
	}

	cerr := &ConnectError{Code: trailer.Error.Code, Message: trailer.Error.Message}
	if cerr.Code == "" {
		cerr.Code = "unknown"
	}
	if cerr.Message == "" {
		cerr.Message = "Unknown error"
	}
	return cerr
}
|
||||||
564
internal/auth/cursor/proto/decode.go
Normal file
564
internal/auth/cursor/proto/decode.go
Normal file
@@ -0,0 +1,564 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerMessageType identifies the kind of decoded server message.
type ServerMessageType int

// One value per message kind DecodeAgentServerMessage can report.
// "Rejected:" marks exec capabilities this client refuses to perform.
const (
	ServerMsgUnknown             ServerMessageType = iota
	ServerMsgTextDelta           // Text content delta
	ServerMsgThinkingDelta       // Thinking/reasoning delta
	ServerMsgThinkingCompleted   // Thinking completed
	ServerMsgKvGetBlob           // Server wants a blob
	ServerMsgKvSetBlob           // Server wants to store a blob
	ServerMsgExecRequestCtx      // Server requests context (tools, etc.)
	ServerMsgExecMcpArgs         // Server wants MCP tool execution
	ServerMsgExecShellArgs       // Rejected: shell command
	ServerMsgExecReadArgs        // Rejected: file read
	ServerMsgExecWriteArgs       // Rejected: file write
	ServerMsgExecDeleteArgs      // Rejected: file delete
	ServerMsgExecLsArgs          // Rejected: directory listing
	ServerMsgExecGrepArgs        // Rejected: grep search
	ServerMsgExecFetchArgs       // Rejected: HTTP fetch
	ServerMsgExecDiagnostics     // Respond with empty diagnostics
	ServerMsgExecShellStream     // Rejected: shell stream
	ServerMsgExecBgShellSpawn    // Rejected: background shell
	ServerMsgExecWriteShellStdin // Rejected: write shell stdin
	ServerMsgExecOther           // Other exec types (respond with empty)
	ServerMsgTurnEnded           // Turn has ended (no more output)
	ServerMsgHeartbeat           // Server heartbeat
	ServerMsgTokenDelta          // Token usage delta
	ServerMsgCheckpoint          // Conversation checkpoint update
)
|
||||||
|
|
||||||
|
// DecodedServerMessage holds parsed data from an AgentServerMessage.
// Only the fields relevant to the decoded Type are populated.
type DecodedServerMessage struct {
	Type ServerMessageType

	// For text/thinking deltas
	Text string

	// For KV messages
	KvId     uint32
	BlobId   []byte // raw blob ID bytes (not hex; use BlobIdHex for a string map key)
	BlobData []byte // for setBlobArgs

	// For exec messages
	ExecMsgId uint32
	ExecId    string

	// For MCP args
	McpToolName   string
	McpToolCallId string
	McpArgs       map[string][]byte // arg name -> protobuf-encoded value

	// For rejection context
	Path             string
	Command          string
	WorkingDirectory string
	Url              string

	// For other exec - the raw field number for building a response
	ExecFieldNumber int

	// For TokenDeltaUpdate
	TokenDelta int64

	// For conversation checkpoint update (raw bytes, not decoded)
	CheckpointData []byte
}
|
||||||
|
|
||||||
|
// DecodeAgentServerMessage parses an AgentServerMessage and returns
// a structured representation of the first meaningful message found.
// Decoding is hand-rolled over the wire format (protowire) instead of using
// generated message types; unknown fields are skipped, and only malformed
// wire data produces an error (msg still carries whatever was decoded).
//
// NOTE(review): sub-decoders overwrite msg fields, so when several
// sub-messages appear in one payload the LAST one wins — "first" above may
// be inaccurate; confirm against callers.
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
	msg := &DecodedServerMessage{Type: ServerMsgUnknown}

	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return msg, fmt.Errorf("invalid tag")
		}
		data = data[n:]

		switch typ {
		case protowire.BytesType:
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return msg, fmt.Errorf("invalid bytes field %d", num)
			}
			data = data[n:]

			// Debug: log top-level ASM fields
			log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))

			switch num {
			case ASM_InteractionUpdate:
				log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
				decodeInteractionUpdate(val, msg)
			case ASM_ExecServerMessage:
				log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
				decodeExecServerMessage(val, msg)
			case ASM_KvServerMessage:
				decodeKvServerMessage(val, msg)
			case ASM_ConversationCheckpoint:
				msg.Type = ServerMsgCheckpoint
				// Copy: val aliases the incoming frame buffer.
				msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
				log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
			}

		case protowire.VarintType:
			// Top-level varint fields carry nothing we use; consume and drop.
			_, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return msg, fmt.Errorf("invalid varint field %d", num)
			}
			data = data[n:]

		default:
			// Skip unknown wire types
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return msg, fmt.Errorf("invalid field %d", num)
			}
			data = data[n:]
		}
	}

	return msg, nil
}
|
||||||
|
|
||||||
|
// decodeInteractionUpdate parses an InteractionUpdate submessage and fills
// msg with the last recognized update kind found (text/thinking deltas,
// token usage, heartbeat, turn end). Malformed wire data aborts silently,
// keeping whatever was decoded so far.
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
	log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
			return
		}
		data = data[n:]
		log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
				return
			}
			data = data[n:]
			log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])

			switch num {
			case IU_TextDelta:
				msg.Type = ServerMsgTextDelta
				msg.Text = decodeStringField(val, TDU_Text)
				log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
			case IU_ThinkingDelta:
				msg.Type = ServerMsgThinkingDelta
				msg.Text = decodeStringField(val, TKD_Text)
				log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
			case IU_ThinkingCompleted:
				msg.Type = ServerMsgThinkingCompleted
				log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
			case 2:
				// tool_call_started - ignore but log
				log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
			case 3:
				// tool_call_completed - ignore but log
				log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
			case 8:
				// token_delta - extract token count from field 1 of the submessage
				msg.Type = ServerMsgTokenDelta
				msg.TokenDelta = decodeVarintField(val, 1)
				log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
			case 13:
				// heartbeat from server
				msg.Type = ServerMsgHeartbeat
			case 14:
				// turn_ended - critical: model finished generating
				msg.Type = ServerMsgTurnEnded
				log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
			case 16:
				// step_started - ignore
				log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
			case 17:
				// step_completed - ignore
				log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
			default:
				log.Debugf("decodeInteractionUpdate: unknown field %d", num)
			}
		} else {
			// Non-bytes fields are not used here; skip them.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// decodeKvServerMessage parses a KvServerMessage: the varint request ID
// (KSM_Id) plus either get-blob or set-blob arguments. Malformed wire data
// aborts silently.
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		switch typ {
		case protowire.VarintType:
			val, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return
			}
			data = data[n:]
			if num == KSM_Id {
				msg.KvId = uint32(val)
			}

		case protowire.BytesType:
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]

			switch num {
			case KSM_GetBlobArgs:
				msg.Type = ServerMsgKvGetBlob
				msg.BlobId = decodeBytesField(val, GBA_BlobId)
			case KSM_SetBlobArgs:
				msg.Type = ServerMsgKvSetBlob
				decodeSetBlobArgs(val, msg)
			}

		default:
			// Skip unused wire types.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
switch num {
|
||||||
|
case SBA_BlobId:
|
||||||
|
msg.BlobId = val
|
||||||
|
case SBA_BlobData:
|
||||||
|
msg.BlobData = val
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeExecServerMessage parses an ExecServerMessage: the varint message ID
// (ESM_Id), the exec ID string, and one args submessage identifying what the
// server wants executed. Unknown args field numbers only set the type when
// nothing has been identified yet, so trailing metadata (e.g. span_context,
// field 19) cannot override the real exec type.
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		switch typ {
		case protowire.VarintType:
			val, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return
			}
			data = data[n:]
			if num == ESM_Id {
				msg.ExecMsgId = uint32(val)
				log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
			}

		case protowire.BytesType:
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]

			// Debug: log all fields found in ExecServerMessage
			log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])

			switch num {
			case ESM_ExecId:
				msg.ExecId = string(val)
				log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
			case ESM_RequestContextArgs:
				msg.Type = ServerMsgExecRequestCtx
			case ESM_McpArgs:
				msg.Type = ServerMsgExecMcpArgs
				decodeMcpArgs(val, msg)
			case ESM_ShellArgs:
				msg.Type = ServerMsgExecShellArgs
				decodeShellArgs(val, msg)
			case ESM_ShellStreamArgs:
				msg.Type = ServerMsgExecShellStream
				decodeShellArgs(val, msg)
			case ESM_ReadArgs:
				msg.Type = ServerMsgExecReadArgs
				msg.Path = decodeStringField(val, RA_Path)
			case ESM_WriteArgs:
				msg.Type = ServerMsgExecWriteArgs
				msg.Path = decodeStringField(val, WA_Path)
			case ESM_DeleteArgs:
				msg.Type = ServerMsgExecDeleteArgs
				msg.Path = decodeStringField(val, DA_Path)
			case ESM_LsArgs:
				msg.Type = ServerMsgExecLsArgs
				msg.Path = decodeStringField(val, LA_Path)
			case ESM_GrepArgs:
				msg.Type = ServerMsgExecGrepArgs
			case ESM_FetchArgs:
				msg.Type = ServerMsgExecFetchArgs
				msg.Url = decodeStringField(val, FA_Url)
			case ESM_DiagnosticsArgs:
				msg.Type = ServerMsgExecDiagnostics
			case ESM_BackgroundShellSpawn:
				msg.Type = ServerMsgExecBgShellSpawn
				decodeShellArgs(val, msg) // same structure
			case ESM_WriteShellStdinArgs:
				msg.Type = ServerMsgExecWriteShellStdin
			default:
				// Unknown exec types - only set if we haven't identified the type yet
				// (other fields like span_context (19) come after the exec type field)
				if msg.Type == ServerMsgUnknown {
					msg.Type = ServerMsgExecOther
					msg.ExecFieldNumber = int(num)
				}
			}

		default:
			// Skip unused wire types.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// decodeMcpArgs parses MCP tool-call arguments: the tool name, the tool call
// ID, and a map of argument name -> protobuf-encoded value. Malformed wire
// data aborts silently.
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
	msg.McpArgs = make(map[string][]byte)
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]

			switch num {
			case MCA_Name:
				msg.McpToolName = string(val)
			case MCA_Args:
				// Map entries are encoded as submessages with key=1, value=2
				decodeMapEntry(val, msg.McpArgs)
			case MCA_ToolCallId:
				msg.McpToolCallId = string(val)
			case MCA_ToolName:
				// ToolName takes precedence if present.
				// (The condition collapses to "overwrite whenever val is
				// non-empty"; when both are empty the assignment is a no-op.)
				if msg.McpToolName == "" || string(val) != "" {
					msg.McpToolName = string(val)
				}
			}
		} else {
			// Skip unused wire types.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// decodeMapEntry parses one protobuf map-entry submessage (key = field 1,
// value = field 2) and stores it in m. The value is copied because it
// aliases the caller's buffer; entries with an empty key are dropped.
func decodeMapEntry(data []byte, m map[string][]byte) {
	var key string
	var value []byte
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]
			if num == 1 {
				key = string(val)
			} else if num == 2 {
				// Copy: val aliases the input buffer.
				value = append([]byte(nil), val...)
			}
		} else {
			// Skip unused wire types.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
	if key != "" {
		m[key] = value
	}
}
|
||||||
|
|
||||||
|
// decodeShellArgs extracts the command and working directory from shell-style
// exec args (shared by shell, shell-stream and background-shell-spawn
// messages). Malformed wire data aborts silently.
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]
			switch num {
			case SHA_Command:
				msg.Command = string(val)
			case SHA_WorkingDirectory:
				msg.WorkingDirectory = string(val)
			}
		} else {
			// Skip unused wire types.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// --- Helper decoders ---
|
||||||
|
|
||||||
|
// decodeStringField extracts a string from the first matching field in a submessage.
|
||||||
|
func decodeStringField(data []byte, targetField protowire.Number) string {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return string(val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeBytesField extracts bytes from the first matching field in a submessage.
|
||||||
|
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return append([]byte(nil), val...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||||
|
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if typ == protowire.VarintType {
|
||||||
|
val, n := protowire.ConsumeVarint(data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return int64(val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlobIdHex renders a blob ID as a lowercase hex string, suitable for use as
// a map key (e.g. into RunRequestParams.BlobStore).
func BlobIdHex(blobId []byte) string {
	return hex.EncodeToString(blobId)
}
|
||||||
|
|
||||||
1244
internal/auth/cursor/proto/descriptor.go
Normal file
1244
internal/auth/cursor/proto/descriptor.go
Normal file
File diff suppressed because it is too large
Load Diff
664
internal/auth/cursor/proto/encode.go
Normal file
664
internal/auth/cursor/proto/encode.go
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
// Package proto provides protobuf encoding for Cursor's gRPC API,
|
||||||
|
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
|
||||||
|
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"sync/atomic"

	log "github.com/sirupsen/logrus"
	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/dynamicpb"
	"google.golang.org/protobuf/types/known/structpb"
)
|
||||||
|
|
||||||
|
// --- Public types ---
|
||||||
|
|
||||||
|
// RunRequestParams holds all data needed to build an AgentRunRequest.
type RunRequestParams struct {
	ModelId        string // Cursor model identifier; also reused for the display fields
	SystemPrompt   string // system prompt; stored in BlobStore keyed by hex(sha256)
	UserText       string // text of the current user message
	MessageId      string // client-chosen ID for the current user message
	ConversationId string // conversation identifier; omitted from the wire when empty
	Images         []ImageData
	Turns          []TurnData
	McpTools       []McpToolDef
	BlobStore      map[string][]byte // hex(sha256) -> data, populated during encoding
	RawCheckpoint  []byte            // if non-nil, use as conversation_state directly (from server checkpoint)
}

// ImageData is one inline image attached to the current user message.
type ImageData struct {
	MimeType string
	Data     []byte
}

// TurnData is one prior user/assistant exchange replayed into the
// conversation state.
type TurnData struct {
	UserText      string
	AssistantText string // empty when the turn had no assistant reply
}

// McpToolDef describes one MCP tool advertised to the agent.
type McpToolDef struct {
	Name        string
	Description string
	InputSchema json.RawMessage // JSON schema; converted to protobuf Value bytes at encode time
}
|
||||||
|
|
||||||
|
// --- Helper: create a dynamic message and set fields ---
|
||||||
|
|
||||||
|
// newMsg instantiates an empty dynamic message for the named type, resolved
// through Msg against the embedded agent.proto descriptor.
func newMsg(name string) *dynamicpb.Message {
	return dynamicpb.NewMessage(Msg(name))
}
|
||||||
|
|
||||||
|
// field looks up a field descriptor by proto field name on msg's descriptor.
// It returns nil when no such field exists (callers that probe optional
// fields, e.g. "turns_old", check for nil).
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
	return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
}
|
||||||
|
|
||||||
|
func setStr(msg *dynamicpb.Message, name, val string) {
|
||||||
|
if val != "" {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
|
||||||
|
if len(val) > 0 {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setUint32 assigns a uint32 field unconditionally (unlike setStr/setBytes,
// zero values are written as well).
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
	msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
}
|
||||||
|
|
||||||
|
// setBool assigns a bool field unconditionally (false is written as well).
func setBool(msg *dynamicpb.Message, name string, val bool) {
	msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
}
|
||||||
|
|
||||||
|
// setMsg assigns a submessage field (for oneof members this selects the case).
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
	msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
}
|
||||||
|
|
||||||
|
// marshal serializes a dynamic message to its binary wire form. A marshal
// failure here can only come from a programming error (e.g. a field set with
// the wrong value kind), so it is treated as a bug and panics rather than
// being threaded through every encoder as an error return.
func marshal(msg *dynamicpb.Message) []byte {
	b, err := proto.Marshal(msg)
	if err != nil {
		panic("cursor proto marshal: " + err.Error())
	}
	return b
}
|
||||||
|
|
||||||
|
// --- Encode functions mirroring cursor-fetch.ts ---
|
||||||
|
|
||||||
|
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
|
||||||
|
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
|
||||||
|
func EncodeHeartbeat() []byte {
|
||||||
|
hb := newMsg("ClientHeartbeat")
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "client_heartbeat", hb)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
|
||||||
|
// Mirrors buildCursorRequest() in cursor-fetch.ts.
|
||||||
|
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
|
||||||
|
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
|
||||||
|
func EncodeRunRequest(p *RunRequestParams) []byte {
|
||||||
|
if p.RawCheckpoint != nil {
|
||||||
|
return encodeRunRequestWithCheckpoint(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.BlobStore == nil {
|
||||||
|
p.BlobStore = make(map[string][]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Conversation turns ---
|
||||||
|
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
|
||||||
|
var turnBytes [][]byte
|
||||||
|
for _, turn := range p.Turns {
|
||||||
|
// UserMessage for this turn
|
||||||
|
um := newMsg("UserMessage")
|
||||||
|
setStr(um, "text", turn.UserText)
|
||||||
|
setStr(um, "message_id", generateId())
|
||||||
|
umBytes := marshal(um)
|
||||||
|
|
||||||
|
// Steps (assistant response)
|
||||||
|
var stepBytes [][]byte
|
||||||
|
if turn.AssistantText != "" {
|
||||||
|
am := newMsg("AssistantMessage")
|
||||||
|
setStr(am, "text", turn.AssistantText)
|
||||||
|
step := newMsg("ConversationStep")
|
||||||
|
setMsg(step, "assistant_message", am)
|
||||||
|
stepBytes = append(stepBytes, marshal(step))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AgentConversationTurnStructure (fields are bytes, not submessages)
|
||||||
|
agentTurn := newMsg("AgentConversationTurnStructure")
|
||||||
|
setBytes(agentTurn, "user_message", umBytes)
|
||||||
|
for _, sb := range stepBytes {
|
||||||
|
stepsField := field(agentTurn, "steps")
|
||||||
|
list := agentTurn.Mutable(stepsField).List()
|
||||||
|
list.Append(protoreflect.ValueOfBytes(sb))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationTurnStructure (oneof turn → agentConversationTurn)
|
||||||
|
cts := newMsg("ConversationTurnStructure")
|
||||||
|
setMsg(cts, "agent_conversation_turn", agentTurn)
|
||||||
|
turnBytes = append(turnBytes, marshal(cts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- System prompt blob ---
|
||||||
|
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
|
||||||
|
blobId := sha256Sum(systemJSON)
|
||||||
|
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
|
||||||
|
|
||||||
|
// --- ConversationStateStructure ---
|
||||||
|
css := newMsg("ConversationStateStructure")
|
||||||
|
// rootPromptMessagesJson: repeated bytes
|
||||||
|
rootField := field(css, "root_prompt_messages_json")
|
||||||
|
rootList := css.Mutable(rootField).List()
|
||||||
|
rootList.Append(protoreflect.ValueOfBytes(blobId))
|
||||||
|
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
|
||||||
|
turnsField := field(css, "turns")
|
||||||
|
turnsList := css.Mutable(turnsField).List()
|
||||||
|
for _, tb := range turnBytes {
|
||||||
|
turnsList.Append(protoreflect.ValueOfBytes(tb))
|
||||||
|
}
|
||||||
|
turnsOldField := field(css, "turns_old")
|
||||||
|
if turnsOldField != nil {
|
||||||
|
turnsOldList := css.Mutable(turnsOldField).List()
|
||||||
|
for _, tb := range turnBytes {
|
||||||
|
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UserMessage (current) ---
|
||||||
|
userMessage := newMsg("UserMessage")
|
||||||
|
setStr(userMessage, "text", p.UserText)
|
||||||
|
setStr(userMessage, "message_id", p.MessageId)
|
||||||
|
|
||||||
|
// Images via SelectedContext
|
||||||
|
if len(p.Images) > 0 {
|
||||||
|
sc := newMsg("SelectedContext")
|
||||||
|
imgsField := field(sc, "selected_images")
|
||||||
|
imgsList := sc.Mutable(imgsField).List()
|
||||||
|
for _, img := range p.Images {
|
||||||
|
si := newMsg("SelectedImage")
|
||||||
|
setStr(si, "uuid", generateId())
|
||||||
|
setStr(si, "mime_type", img.MimeType)
|
||||||
|
setBytes(si, "data", img.Data)
|
||||||
|
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||||
|
}
|
||||||
|
setMsg(userMessage, "selected_context", sc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UserMessageAction ---
|
||||||
|
uma := newMsg("UserMessageAction")
|
||||||
|
setMsg(uma, "user_message", userMessage)
|
||||||
|
|
||||||
|
// --- ConversationAction ---
|
||||||
|
ca := newMsg("ConversationAction")
|
||||||
|
setMsg(ca, "user_message_action", uma)
|
||||||
|
|
||||||
|
// --- ModelDetails ---
|
||||||
|
md := newMsg("ModelDetails")
|
||||||
|
setStr(md, "model_id", p.ModelId)
|
||||||
|
setStr(md, "display_model_id", p.ModelId)
|
||||||
|
setStr(md, "display_name", p.ModelId)
|
||||||
|
|
||||||
|
// --- AgentRunRequest ---
|
||||||
|
arr := newMsg("AgentRunRequest")
|
||||||
|
setMsg(arr, "conversation_state", css)
|
||||||
|
setMsg(arr, "action", ca)
|
||||||
|
setMsg(arr, "model_details", md)
|
||||||
|
setStr(arr, "conversation_id", p.ConversationId)
|
||||||
|
|
||||||
|
// McpTools
|
||||||
|
if len(p.McpTools) > 0 {
|
||||||
|
mcpTools := newMsg("McpTools")
|
||||||
|
toolsField := field(mcpTools, "mcp_tools")
|
||||||
|
toolsList := mcpTools.Mutable(toolsField).List()
|
||||||
|
for _, tool := range p.McpTools {
|
||||||
|
td := newMsg("McpToolDefinition")
|
||||||
|
setStr(td, "name", tool.Name)
|
||||||
|
setStr(td, "description", tool.Description)
|
||||||
|
if len(tool.InputSchema) > 0 {
|
||||||
|
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||||
|
}
|
||||||
|
setStr(td, "provider_identifier", "proxy")
|
||||||
|
setStr(td, "tool_name", tool.Name)
|
||||||
|
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||||
|
}
|
||||||
|
setMsg(arr, "mcp_tools", mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- AgentClientMessage ---
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "run_request", arr)
|
||||||
|
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
// as conversation_state. The checkpoint bytes are embedded directly without
// deserialization: the submessages are built with dynamicpb, but the outer
// AgentRunRequest and AgentClientMessage are assembled by hand with protowire
// so the opaque checkpoint blob can be spliced in as field 1 unchanged.
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
	// Build UserMessage for the current turn.
	userMessage := newMsg("UserMessage")
	setStr(userMessage, "text", p.UserText)
	setStr(userMessage, "message_id", p.MessageId)
	if len(p.Images) > 0 {
		sc := newMsg("SelectedContext")
		imgsField := field(sc, "selected_images")
		imgsList := sc.Mutable(imgsField).List()
		for _, img := range p.Images {
			si := newMsg("SelectedImage")
			setStr(si, "uuid", generateId())
			setStr(si, "mime_type", img.MimeType)
			setBytes(si, "data", img.Data)
			imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
		}
		setMsg(userMessage, "selected_context", sc)
	}

	// Build ConversationAction with UserMessageAction.
	uma := newMsg("UserMessageAction")
	setMsg(uma, "user_message", userMessage)
	ca := newMsg("ConversationAction")
	setMsg(ca, "user_message_action", uma)
	caBytes := marshal(ca)

	// Build ModelDetails (display fields mirror the model ID).
	md := newMsg("ModelDetails")
	setStr(md, "model_id", p.ModelId)
	setStr(md, "display_model_id", p.ModelId)
	setStr(md, "display_name", p.ModelId)
	mdBytes := marshal(md)

	// Build McpTools; left nil (and omitted below) when no tools are configured.
	var mcpToolsBytes []byte
	if len(p.McpTools) > 0 {
		mcpTools := newMsg("McpTools")
		toolsField := field(mcpTools, "mcp_tools")
		toolsList := mcpTools.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
		mcpToolsBytes = marshal(mcpTools)
	}

	// Manually assemble AgentRunRequest using protowire to embed the raw
	// checkpoint. Field numbers come from fieldnumbers.go and must match
	// the agent.proto schema.
	var arrBuf []byte
	// field 1: conversation_state = raw checkpoint bytes (length-delimited)
	arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
	arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
	// field 2: action = ConversationAction
	arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
	arrBuf = protowire.AppendBytes(arrBuf, caBytes)
	// field 3: model_details = ModelDetails
	arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
	arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
	// field 4: mcp_tools = McpTools (omitted when there are no tools)
	if len(mcpToolsBytes) > 0 {
		arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
		arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
	}
	// field 5: conversation_id = string (omitted when empty)
	if p.ConversationId != "" {
		arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
		arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
	}

	// Wrap in AgentClientMessage field 1 (run_request).
	var acmBuf []byte
	acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
	acmBuf = protowire.AppendBytes(acmBuf, arrBuf)

	log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
	return acmBuf
}
|
||||||
|
|
||||||
|
// ResumeRequestParams holds data for a ResumeAction request.
type ResumeRequestParams struct {
	ModelId        string       // model to continue the conversation with
	ConversationId string       // server-side conversation to resume
	McpTools       []McpToolDef // tools re-advertised when resuming
}
|
||||||
|
|
||||||
|
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
// Used to resume a conversation by conversation_id without re-sending full
// history. Tool definitions are attached twice — inside the ResumeAction's
// RequestContext and as top-level mcp_tools on the AgentRunRequest — matching
// what the regular run request sends.
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
	// RequestContext with tools.
	rc := newMsg("RequestContext")
	if len(p.McpTools) > 0 {
		toolsField := field(rc, "tools")
		toolsList := rc.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
	}

	// ResumeAction
	ra := newMsg("ResumeAction")
	setMsg(ra, "request_context", rc)

	// ConversationAction with resume_action
	ca := newMsg("ConversationAction")
	setMsg(ca, "resume_action", ra)

	// ModelDetails (display fields mirror the model ID)
	md := newMsg("ModelDetails")
	setStr(md, "model_id", p.ModelId)
	setStr(md, "display_model_id", p.ModelId)
	setStr(md, "display_name", p.ModelId)

	// AgentRunRequest — no conversation_state needed for resume
	arr := newMsg("AgentRunRequest")
	setMsg(arr, "action", ca)
	setMsg(arr, "model_details", md)
	setStr(arr, "conversation_id", p.ConversationId)

	// McpTools at top level
	if len(p.McpTools) > 0 {
		mcpTools := newMsg("McpTools")
		toolsField := field(mcpTools, "mcp_tools")
		toolsList := mcpTools.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
		setMsg(arr, "mcp_tools", mcpTools)
	}

	acm := newMsg("AgentClientMessage")
	setMsg(acm, "run_request", arr)
	return marshal(acm)
}
|
||||||
|
|
||||||
|
// --- KV response encoders ---
|
||||||
|
// Mirrors handleKvMessage() in cursor-fetch.ts
|
||||||
|
|
||||||
|
// EncodeKvGetBlobResult responds to a getBlobArgs request.
|
||||||
|
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
|
||||||
|
result := newMsg("GetBlobResult")
|
||||||
|
if blobData != nil {
|
||||||
|
setBytes(result, "blob_data", blobData)
|
||||||
|
}
|
||||||
|
|
||||||
|
kvc := newMsg("KvClientMessage")
|
||||||
|
setUint32(kvc, "id", kvId)
|
||||||
|
setMsg(kvc, "get_blob_result", result)
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "kv_client_message", kvc)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeKvSetBlobResult responds to a setBlobArgs request.
|
||||||
|
func EncodeKvSetBlobResult(kvId uint32) []byte {
|
||||||
|
result := newMsg("SetBlobResult")
|
||||||
|
|
||||||
|
kvc := newMsg("KvClientMessage")
|
||||||
|
setUint32(kvc, "id", kvId)
|
||||||
|
setMsg(kvc, "set_blob_result", result)
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "kv_client_message", kvc)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Exec response encoders ---
|
||||||
|
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
|
||||||
|
|
||||||
|
// EncodeExecRequestContextResult responds to requestContextArgs with tool
// definitions. The tools are wrapped success-side in the RequestContextResult
// oneof and sent back through the exec channel identified by execMsgId/execId.
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
	// RequestContext with tools.
	rc := newMsg("RequestContext")
	if len(tools) > 0 {
		toolsField := field(rc, "tools")
		toolsList := rc.Mutable(toolsField).List()
		for _, tool := range tools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
	}

	// RequestContextSuccess
	rcs := newMsg("RequestContextSuccess")
	setMsg(rcs, "request_context", rc)

	// RequestContextResult (oneof success)
	rcr := newMsg("RequestContextResult")
	setMsg(rcr, "success", rcs)

	return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
}
|
||||||
|
|
||||||
|
// EncodeExecMcpResult responds with MCP tool result.
|
||||||
|
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
|
||||||
|
textContent := newMsg("McpTextContent")
|
||||||
|
setStr(textContent, "text", content)
|
||||||
|
|
||||||
|
contentItem := newMsg("McpToolResultContentItem")
|
||||||
|
setMsg(contentItem, "text", textContent)
|
||||||
|
|
||||||
|
success := newMsg("McpSuccess")
|
||||||
|
contentField := field(success, "content")
|
||||||
|
contentList := success.Mutable(contentField).List()
|
||||||
|
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
|
||||||
|
setBool(success, "is_error", isError)
|
||||||
|
|
||||||
|
result := newMsg("McpResult")
|
||||||
|
setMsg(result, "success", success)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeExecMcpError responds with MCP error.
|
||||||
|
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
mcpErr := newMsg("McpError")
|
||||||
|
setStr(mcpErr, "error", errMsg)
|
||||||
|
|
||||||
|
result := newMsg("McpResult")
|
||||||
|
setMsg(result, "error", mcpErr)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Rejection encoders (mirror handleExecMessage rejections) ---
|
||||||
|
|
||||||
|
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("ReadRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("ReadResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||||
|
rej := newMsg("ShellRejected")
|
||||||
|
setStr(rej, "command", command)
|
||||||
|
setStr(rej, "working_directory", workDir)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("ShellResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("WriteRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("WriteResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("DeleteRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("DeleteResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("LsRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("LsResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
grepErr := newMsg("GrepError")
|
||||||
|
setStr(grepErr, "error", errMsg)
|
||||||
|
result := newMsg("GrepResult")
|
||||||
|
setMsg(result, "error", grepErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
|
||||||
|
fetchErr := newMsg("FetchError")
|
||||||
|
setStr(fetchErr, "url", url)
|
||||||
|
setStr(fetchErr, "error", errMsg)
|
||||||
|
result := newMsg("FetchResult")
|
||||||
|
setMsg(result, "error", fetchErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
|
||||||
|
result := newMsg("DiagnosticsResult")
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||||
|
rej := newMsg("ShellRejected")
|
||||||
|
setStr(rej, "command", command)
|
||||||
|
setStr(rej, "working_directory", workDir)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("BackgroundShellSpawnResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
wsErr := newMsg("WriteShellStdinError")
|
||||||
|
setStr(wsErr, "error", errMsg)
|
||||||
|
result := newMsg("WriteShellStdinResult")
|
||||||
|
setMsg(result, "error", wsErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
// Mirrors sendExec() in cursor-fetch.ts.
//
// resultFieldName selects the oneof branch of ExecClientMessage; passing a
// name that does not exist is a programmer error and panics. A panic here is
// preferred over silently dropping the result, since the exec channel would
// otherwise hang waiting for a reply.
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
	ecm := newMsg("ExecClientMessage")
	setUint32(ecm, "id", id)
	// Force set exec_id even if empty - Cursor requires this field to be set
	// (setStr would skip an empty string, so Set is called directly).
	ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))

	// Guard: verify the oneof branch exists before setting it.
	fd := field(ecm, resultFieldName)
	if fd == nil {
		panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
	}

	// Debug: log the actual field being set.
	log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())

	ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))

	acm := newMsg("AgentClientMessage")
	setMsg(acm, "exec_client_message", ecm)
	return marshal(acm)
}
|
||||||
|
|
||||||
|
func listFields(msg *dynamicpb.Message) []string {
|
||||||
|
var names []string
|
||||||
|
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
|
||||||
|
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Utilities ---
|
||||||
|
|
||||||
|
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
//
// NOTE(review): on any failure this falls back to returning the raw JSON
// bytes, which are NOT valid protobuf Value binary — the peer would see a
// malformed input_schema in that case. Confirm whether an empty/nil return
// would be the safer fallback.
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
	if len(jsonData) == 0 {
		return nil
	}
	var v interface{}
	if err := json.Unmarshal(jsonData, &v); err != nil {
		return jsonData // fallback to raw JSON if parsing fails
	}
	pbVal, err := structpb.NewValue(v)
	if err != nil {
		return jsonData // fallback
	}
	b, err := proto.Marshal(pbVal)
	if err != nil {
		return jsonData // fallback
	}
	return b
}
|
||||||
|
|
||||||
|
// ProtobufValueBytesToJSON converts protobuf Value binary back to a Go value
// (the json.Marshal-compatible result of structpb.Value.AsInterface).
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
// It returns an error when data is not a valid serialized Value.
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
	val := &structpb.Value{}
	if err := proto.Unmarshal(data, val); err != nil {
		return nil, err
	}
	return val.AsInterface(), nil
}
|
||||||
|
|
||||||
|
// sha256Sum returns the 32-byte SHA-256 digest of data as a slice.
func sha256Sum(data []byte) []byte {
	digest := sha256.Sum256(data)
	out := make([]byte, len(digest))
	copy(out, digest[:])
	return out
}
|
||||||
|
|
||||||
|
// idCounter is a process-local source of uniqueness for generated IDs.
// Accessed only via sync/atomic so concurrent encoders never race.
var idCounter uint64

// generateId returns a 32-character hex ID derived from a monotonically
// increasing counter hashed through SHA-256 (first 16 digest bytes).
//
// Fixes two defects of the previous version: the bare idCounter++ was a data
// race under concurrent use, and only the low 24 bits of the counter were
// hashed, so IDs repeated after 2^24 calls.
func generateId() string {
	c := atomic.AddUint64(&idCounter, 1)
	var seed [8]byte
	for i := range seed {
		seed[i] = byte(c >> (8 * uint(i)))
	}
	h := sha256.Sum256(seed[:])
	return hex.EncodeToString(h[:16])
}
|
||||||
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
|
||||||
|
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
|
||||||
|
package proto
|
||||||
|
|
||||||
|
// AgentClientMessage (msg 118) oneof "message".
// These constants are protowire field numbers mirrored from the generated
// agent_pb.ts; they must stay in sync with the agent.proto schema and must
// never be renumbered independently.
const (
	ACM_RunRequest           = 1 // AgentRunRequest
	ACM_ExecClientMessage    = 2 // ExecClientMessage
	ACM_KvClientMessage      = 3 // KvClientMessage
	ACM_ConversationAction   = 4 // ConversationAction
	ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
	ACM_InteractionResponse  = 6 // InteractionResponse
	ACM_ClientHeartbeat      = 7 // ClientHeartbeat
)

// AgentServerMessage (msg 119) oneof "message".
// Note: field number 6 is intentionally absent here (not used by this proxy).
const (
	ASM_InteractionUpdate        = 1 // InteractionUpdate
	ASM_ExecServerMessage        = 2 // ExecServerMessage
	ASM_ConversationCheckpoint   = 3 // ConversationStateStructure
	ASM_KvServerMessage          = 4 // KvServerMessage
	ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
	ASM_InteractionQuery         = 7 // InteractionQuery
)

// AgentRunRequest (msg 91)
const (
	ARR_ConversationState = 1 // ConversationStateStructure
	ARR_Action            = 2 // ConversationAction
	ARR_ModelDetails      = 3 // ModelDetails
	ARR_McpTools          = 4 // McpTools
	ARR_ConversationId    = 5 // string (optional)
)

// ConversationStateStructure (msg 83)
const (
	CSS_RootPromptMessagesJson = 1  // repeated bytes
	CSS_TurnsOld               = 2  // repeated bytes (deprecated)
	CSS_Todos                  = 3  // repeated bytes
	CSS_PendingToolCalls       = 4  // repeated string
	CSS_Turns                  = 8  // repeated bytes (CURRENT field for turns)
	CSS_PreviousWorkspaceUris  = 9  // repeated string
	CSS_SelfSummaryCount       = 17 // uint32
	CSS_ReadPaths              = 18 // repeated string
)
|
||||||
|
|
||||||
|
// ConversationAction (msg 54) oneof "action"
const (
	CA_UserMessageAction = 1 // UserMessageAction
)

// UserMessageAction (msg 55)
const (
	UMA_UserMessage = 1 // UserMessage
)

// UserMessage (msg 63)
const (
	UM_Text            = 1 // string
	UM_MessageId       = 2 // string
	UM_SelectedContext = 3 // SelectedContext (optional)
)

// SelectedContext
const (
	SC_SelectedImages = 1 // repeated SelectedImage
)

// SelectedImage — fields 1 and 8 form the dataOrBlobId oneof: an image is
// carried either by reference (blob ID) or inline (raw bytes), never both.
const (
	SI_BlobId   = 1 // bytes (oneof dataOrBlobId)
	SI_Uuid     = 2 // string
	SI_Path     = 3 // string
	SI_MimeType = 7 // string
	SI_Data     = 8 // bytes (oneof dataOrBlobId)
)

// ModelDetails (msg 88)
const (
	MD_ModelId         = 1 // string
	MD_ThinkingDetails = 2 // ThinkingDetails (optional)
	MD_DisplayModelId  = 3 // string
	MD_DisplayName     = 4 // string
)
|
||||||
|
|
||||||
|
// McpTools (msg 307)
|
||||||
|
const (
|
||||||
|
MT_McpTools = 1 // repeated McpToolDefinition
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpToolDefinition (msg 306)
|
||||||
|
const (
|
||||||
|
MTD_Name = 1 // string
|
||||||
|
MTD_Description = 2 // string
|
||||||
|
MTD_InputSchema = 3 // bytes
|
||||||
|
MTD_ProviderIdentifier = 4 // string
|
||||||
|
MTD_ToolName = 5 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationTurnStructure (msg 70) oneof "turn"
|
||||||
|
const (
|
||||||
|
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentConversationTurnStructure (msg 72)
|
||||||
|
const (
|
||||||
|
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
|
||||||
|
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationStep (msg 53) oneof "message"
|
||||||
|
const (
|
||||||
|
CS_AssistantMessage = 1 // AssistantMessage
|
||||||
|
)
|
||||||
|
|
||||||
|
// AssistantMessage
|
||||||
|
const (
|
||||||
|
AM_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Server-side message fields ---
|
||||||
|
|
||||||
|
// InteractionUpdate oneof "message"
|
||||||
|
const (
|
||||||
|
IU_TextDelta = 1 // TextDeltaUpdate
|
||||||
|
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||||
|
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||||
|
)
|
||||||
|
|
||||||
|
// TextDeltaUpdate (msg 92)
|
||||||
|
const (
|
||||||
|
TDU_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThinkingDeltaUpdate (msg 97)
|
||||||
|
const (
|
||||||
|
TKD_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// KvServerMessage (msg 271)
|
||||||
|
const (
|
||||||
|
KSM_Id = 1 // uint32
|
||||||
|
KSM_GetBlobArgs = 2 // GetBlobArgs
|
||||||
|
KSM_SetBlobArgs = 3 // SetBlobArgs
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetBlobArgs (msg 267)
|
||||||
|
const (
|
||||||
|
GBA_BlobId = 1 // bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetBlobArgs (msg 269)
|
||||||
|
const (
|
||||||
|
SBA_BlobId = 1 // bytes
|
||||||
|
SBA_BlobData = 2 // bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
// KvClientMessage (msg 272)
|
||||||
|
const (
|
||||||
|
KCM_Id = 1 // uint32
|
||||||
|
KCM_GetBlobResult = 2 // GetBlobResult
|
||||||
|
KCM_SetBlobResult = 3 // SetBlobResult
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetBlobResult (msg 268)
|
||||||
|
const (
|
||||||
|
GBR_BlobData = 1 // bytes (optional)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecServerMessage
|
||||||
|
const (
|
||||||
|
ESM_Id = 1 // uint32
|
||||||
|
ESM_ExecId = 15 // string
|
||||||
|
// oneof message:
|
||||||
|
ESM_ShellArgs = 2 // ShellArgs
|
||||||
|
ESM_WriteArgs = 3 // WriteArgs
|
||||||
|
ESM_DeleteArgs = 4 // DeleteArgs
|
||||||
|
ESM_GrepArgs = 5 // GrepArgs
|
||||||
|
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||||
|
ESM_LsArgs = 8 // LsArgs
|
||||||
|
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||||
|
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||||
|
ESM_McpArgs = 11 // McpArgs
|
||||||
|
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||||
|
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||||
|
ESM_FetchArgs = 20 // FetchArgs
|
||||||
|
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecClientMessage
|
||||||
|
const (
|
||||||
|
ECM_Id = 1 // uint32
|
||||||
|
ECM_ExecId = 15 // string
|
||||||
|
// oneof message (mirrors server fields):
|
||||||
|
ECM_ShellResult = 2
|
||||||
|
ECM_WriteResult = 3
|
||||||
|
ECM_DeleteResult = 4
|
||||||
|
ECM_GrepResult = 5
|
||||||
|
ECM_ReadResult = 7
|
||||||
|
ECM_LsResult = 8
|
||||||
|
ECM_DiagnosticsResult = 9
|
||||||
|
ECM_RequestContextResult = 10
|
||||||
|
ECM_McpResult = 11
|
||||||
|
ECM_ShellStream = 14
|
||||||
|
ECM_BackgroundShellSpawnRes = 16
|
||||||
|
ECM_FetchResult = 20
|
||||||
|
ECM_WriteShellStdinResult = 23
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpArgs
|
||||||
|
const (
|
||||||
|
MCA_Name = 1 // string
|
||||||
|
MCA_Args = 2 // map<string, bytes>
|
||||||
|
MCA_ToolCallId = 3 // string
|
||||||
|
MCA_ProviderIdentifier = 4 // string
|
||||||
|
MCA_ToolName = 5 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContextResult oneof "result"
|
||||||
|
const (
|
||||||
|
RCR_Success = 1 // RequestContextSuccess
|
||||||
|
RCR_Error = 2 // RequestContextError
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContextSuccess (msg 337)
|
||||||
|
const (
|
||||||
|
RCS_RequestContext = 1 // RequestContext
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContext
|
||||||
|
const (
|
||||||
|
RC_Rules = 2 // repeated CursorRule
|
||||||
|
RC_Tools = 7 // repeated McpToolDefinition
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpResult oneof "result"
|
||||||
|
const (
|
||||||
|
MCR_Success = 1 // McpSuccess
|
||||||
|
MCR_Error = 2 // McpError
|
||||||
|
MCR_Rejected = 3 // McpRejected
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpSuccess (msg 290)
|
||||||
|
const (
|
||||||
|
MCS_Content = 1 // repeated McpToolResultContentItem
|
||||||
|
MCS_IsError = 2 // bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpToolResultContentItem oneof "content"
|
||||||
|
const (
|
||||||
|
MTRCI_Text = 1 // McpTextContent
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpTextContent (msg 287)
|
||||||
|
const (
|
||||||
|
MTC_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpError (msg 291)
|
||||||
|
const (
|
||||||
|
MCE_Error = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Rejection messages ---
|
||||||
|
|
||||||
|
// ReadRejected: path=1, reason=2
|
||||||
|
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
|
||||||
|
// WriteRejected: path=1, reason=2
|
||||||
|
// DeleteRejected: path=1, reason=2
|
||||||
|
// LsRejected: path=1, reason=2
|
||||||
|
// GrepError: error=1
|
||||||
|
// FetchError: url=1, error=2
|
||||||
|
// WriteShellStdinError: error=1
|
||||||
|
|
||||||
|
// ReadResult oneof: success=1, error=2, rejected=3
|
||||||
|
// ShellResult oneof: success=1 (+ various), rejected=?
|
||||||
|
// The TS code uses specific result field numbers from the oneof:
|
||||||
|
const (
|
||||||
|
RR_Rejected = 3 // ReadResult.rejected
|
||||||
|
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||||
|
WR_Rejected = 5 // WriteResult.rejected
|
||||||
|
DR_Rejected = 3 // DeleteResult.rejected
|
||||||
|
LR_Rejected = 3 // LsResult.rejected
|
||||||
|
GR_Error = 2 // GrepResult.error
|
||||||
|
FR_Error = 2 // FetchResult.error
|
||||||
|
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
|
||||||
|
WSSR_Error = 2 // WriteShellStdinResult.error
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Rejection struct fields ---
|
||||||
|
const (
|
||||||
|
REJ_Path = 1
|
||||||
|
REJ_Reason = 2
|
||||||
|
SREJ_Command = 1
|
||||||
|
SREJ_WorkingDir = 2
|
||||||
|
SREJ_Reason = 3
|
||||||
|
SREJ_IsReadonly = 4
|
||||||
|
GERR_Error = 1
|
||||||
|
FERR_Url = 1
|
||||||
|
FERR_Error = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadArgs
|
||||||
|
const (
|
||||||
|
RA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteArgs
|
||||||
|
const (
|
||||||
|
WA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeleteArgs
|
||||||
|
const (
|
||||||
|
DA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// LsArgs
|
||||||
|
const (
|
||||||
|
LA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShellArgs
|
||||||
|
const (
|
||||||
|
SHA_Command = 1 // string
|
||||||
|
SHA_WorkingDirectory = 2 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// FetchArgs
|
||||||
|
const (
|
||||||
|
FA_Url = 1 // string
|
||||||
|
)
|
||||||
313
internal/auth/cursor/proto/h2stream.go
Normal file
313
internal/auth/cursor/proto/h2stream.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||||
|
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||||
|
)
|
||||||
|
|
||||||
|
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||||
|
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||||
|
type H2Stream struct {
|
||||||
|
framer *http2.Framer
|
||||||
|
conn net.Conn
|
||||||
|
streamID uint32
|
||||||
|
mu sync.Mutex
|
||||||
|
id string // unique identifier for debugging
|
||||||
|
frameNum int64 // sequential frame counter for debugging
|
||||||
|
|
||||||
|
dataCh chan []byte
|
||||||
|
doneCh chan struct{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
// Send-side flow control
|
||||||
|
sendWindow int32 // available bytes we can send on this stream
|
||||||
|
connWindow int32 // available bytes on the connection level
|
||||||
|
windowCond *sync.Cond // signaled when window is updated
|
||||||
|
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the unique identifier for this stream (for logging).
|
||||||
|
func (s *H2Stream) ID() string { return s.id }
|
||||||
|
|
||||||
|
// FrameNum returns the current frame number for debugging.
|
||||||
|
func (s *H2Stream) FrameNum() int64 {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.frameNum
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
|
||||||
|
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||||
|
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
|
||||||
|
}
|
||||||
|
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: server did not negotiate h2")
|
||||||
|
}
|
||||||
|
|
||||||
|
framer := http2.NewFramer(tlsConn, tlsConn)
|
||||||
|
|
||||||
|
// Client connection preface
|
||||||
|
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial SETTINGS (tell server how much WE can receive)
|
||||||
|
if err := framer.WriteSettings(
|
||||||
|
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||||
|
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||||
|
); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection-level window update (for receiving)
|
||||||
|
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||||
|
// Track server's initial window size (how much WE can send)
|
||||||
|
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||||
|
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
f, err := framer.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
|
||||||
|
}
|
||||||
|
switch sf := f.(type) {
|
||||||
|
case *http2.SettingsFrame:
|
||||||
|
if !sf.IsAck() {
|
||||||
|
sf.ForeachSetting(func(s http2.Setting) error {
|
||||||
|
if s.ID == http2.SettingInitialWindowSize {
|
||||||
|
serverInitialWindowSize = int32(s.Val)
|
||||||
|
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
framer.WriteSettingsAck()
|
||||||
|
} else {
|
||||||
|
goto handshakeDone
|
||||||
|
}
|
||||||
|
case *http2.WindowUpdateFrame:
|
||||||
|
if sf.StreamID == 0 {
|
||||||
|
connWindowSize += int32(sf.Increment)
|
||||||
|
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// unexpected but continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handshakeDone:
|
||||||
|
|
||||||
|
// Build HEADERS
|
||||||
|
streamID := uint32(1)
|
||||||
|
var hdrBuf []byte
|
||||||
|
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
|
||||||
|
if p, ok := headers[":path"]; ok {
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
|
||||||
|
}
|
||||||
|
for k, v := range headers {
|
||||||
|
if len(k) > 0 && k[0] == ':' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||||
|
StreamID: streamID,
|
||||||
|
BlockFragment: hdrBuf,
|
||||||
|
EndStream: false,
|
||||||
|
EndHeaders: true,
|
||||||
|
}); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: headers write failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &H2Stream{
|
||||||
|
framer: framer,
|
||||||
|
conn: tlsConn,
|
||||||
|
streamID: streamID,
|
||||||
|
dataCh: make(chan []byte, 256),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||||
|
frameNum: 0,
|
||||||
|
sendWindow: serverInitialWindowSize,
|
||||||
|
connWindow: connWindowSize,
|
||||||
|
}
|
||||||
|
s.windowCond = sync.NewCond(&s.windowMu)
|
||||||
|
go s.readLoop()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write sends a DATA frame on the stream, respecting flow control.
|
||||||
|
func (s *H2Stream) Write(data []byte) error {
|
||||||
|
for len(data) > 0 {
|
||||||
|
chunk := data
|
||||||
|
if len(chunk) > maxFramePayload {
|
||||||
|
chunk = data[:maxFramePayload]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for flow control window
|
||||||
|
s.windowMu.Lock()
|
||||||
|
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||||
|
s.windowCond.Wait()
|
||||||
|
}
|
||||||
|
// Limit chunk to available window
|
||||||
|
allowed := int(s.sendWindow)
|
||||||
|
if int(s.connWindow) < allowed {
|
||||||
|
allowed = int(s.connWindow)
|
||||||
|
}
|
||||||
|
if len(chunk) > allowed {
|
||||||
|
chunk = chunk[:allowed]
|
||||||
|
}
|
||||||
|
s.sendWindow -= int32(len(chunk))
|
||||||
|
s.connWindow -= int32(len(chunk))
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||||
|
s.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data = data[len(chunk):]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data returns the channel of received data chunks.
|
||||||
|
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
|
||||||
|
|
||||||
|
// Done returns a channel closed when the stream ends.
|
||||||
|
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
||||||
|
|
||||||
|
// Err returns the error (if any) that caused the stream to close.
|
||||||
|
// Returns nil for a clean shutdown (EOF / StreamEnded).
|
||||||
|
func (s *H2Stream) Err() error { return s.err }
|
||||||
|
|
||||||
|
// Close tears down the connection.
|
||||||
|
func (s *H2Stream) Close() {
|
||||||
|
s.conn.Close()
|
||||||
|
// Unblock any writers waiting on flow control
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *H2Stream) readLoop() {
|
||||||
|
defer close(s.doneCh)
|
||||||
|
defer close(s.dataCh)
|
||||||
|
|
||||||
|
for {
|
||||||
|
f, err := s.framer.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
s.err = err
|
||||||
|
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment frame counter
|
||||||
|
s.mu.Lock()
|
||||||
|
s.frameNum++
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
switch frame := f.(type) {
|
||||||
|
case *http2.DataFrame:
|
||||||
|
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||||
|
cp := make([]byte, len(frame.Data()))
|
||||||
|
copy(cp, frame.Data())
|
||||||
|
s.dataCh <- cp
|
||||||
|
|
||||||
|
// Flow control: send WINDOW_UPDATE for received data
|
||||||
|
s.mu.Lock()
|
||||||
|
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||||
|
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
if frame.StreamEnded() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.HeadersFrame:
|
||||||
|
if frame.StreamEnded() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.RSTStreamFrame:
|
||||||
|
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||||
|
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||||
|
return
|
||||||
|
|
||||||
|
case *http2.GoAwayFrame:
|
||||||
|
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||||
|
return
|
||||||
|
|
||||||
|
case *http2.PingFrame:
|
||||||
|
if !frame.IsAck() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.framer.WritePing(true, frame.Data)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.SettingsFrame:
|
||||||
|
if !frame.IsAck() {
|
||||||
|
// Check for window size changes
|
||||||
|
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||||
|
if setting.ID == http2.SettingInitialWindowSize {
|
||||||
|
s.windowMu.Lock()
|
||||||
|
delta := int32(setting.Val) - s.sendWindow
|
||||||
|
s.sendWindow += delta
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
s.mu.Lock()
|
||||||
|
s.framer.WriteSettingsAck()
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.WindowUpdateFrame:
|
||||||
|
// Update send-side flow control window
|
||||||
|
s.windowMu.Lock()
|
||||||
|
if frame.StreamID == 0 {
|
||||||
|
s.connWindow += int32(frame.Increment)
|
||||||
|
} else if frame.StreamID == s.streamID {
|
||||||
|
s.sendWindow += int32(frame.Increment)
|
||||||
|
}
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sliceWriter struct{ buf *[]byte }
|
||||||
|
|
||||||
|
func (w *sliceWriter) Write(p []byte) (int, error) {
|
||||||
|
*w.buf = append(*w.buf, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
@@ -10,9 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
@@ -20,9 +18,9 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
}
|
}
|
||||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
if errBuild != nil {
|
||||||
if err == nil {
|
log.Errorf("%v", errBuild)
|
||||||
var transport *http.Transport
|
} else if transport != nil {
|
||||||
if proxyURL.Scheme == "socks5" {
|
proxyClient := &http.Client{Transport: transport}
|
||||||
// Handle SOCKS5 proxy.
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
auth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
|
||||||
}
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Handle HTTP/HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
if transport != nil {
|
|
||||||
proxyClient := &http.Client{Transport: transport}
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: ClientID,
|
ClientID: ClientID,
|
||||||
@@ -327,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
defer manualPromptTimer.Stop()
|
defer manualPromptTimer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var manualInputCh <-chan string
|
||||||
|
var manualInputErrCh <-chan error
|
||||||
|
|
||||||
waitForCallback:
|
waitForCallback:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -348,13 +329,14 @@ waitForCallback:
|
|||||||
return nil, err
|
return nil, err
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||||
if err != nil {
|
continue
|
||||||
return nil, err
|
case input := <-manualInputCh:
|
||||||
}
|
manualInputCh = nil
|
||||||
parsed, err := misc.ParseOAuthCallback(input)
|
manualInputErrCh = nil
|
||||||
if err != nil {
|
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||||
return nil, err
|
if errParse != nil {
|
||||||
|
return nil, errParse
|
||||||
}
|
}
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
continue
|
continue
|
||||||
@@ -367,6 +349,8 @@ waitForCallback:
|
|||||||
}
|
}
|
||||||
authCode = parsed.Code
|
authCode = parsed.Code
|
||||||
break waitForCallback
|
break waitForCallback
|
||||||
|
case errManual := <-manualInputErrCh:
|
||||||
|
return nil, errManual
|
||||||
case <-timeoutTimer.C:
|
case <-timeoutTimer.C:
|
||||||
return nil, fmt.Errorf("oauth flow timed out")
|
return nil, fmt.Errorf("oauth flow timed out")
|
||||||
}
|
}
|
||||||
|
|||||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
package gitlab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultBaseURL = "https://gitlab.com"
|
||||||
|
DefaultCallbackPort = 17171
|
||||||
|
defaultOAuthScope = "api read_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PKCECodes struct {
|
||||||
|
CodeVerifier string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthResult struct {
|
||||||
|
Code string
|
||||||
|
State string
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthServer struct {
|
||||||
|
server *http.Server
|
||||||
|
port int
|
||||||
|
resultChan chan *OAuthResult
|
||||||
|
errorChan chan error
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
PublicEmail string `json:"public_email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PersonalAccessTokenSelf struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Scopes []string `json:"scopes"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelDetails struct {
|
||||||
|
ModelProvider string `json:"model_provider"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DirectAccessResponse struct {
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
Headers map[string]string `json:"headers"`
|
||||||
|
ModelDetails *ModelDetails `json:"model_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiscoveredModel struct {
|
||||||
|
ModelProvider string
|
||||||
|
ModelName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthClient(cfg *config.Config) *AuthClient {
|
||||||
|
client := &http.Client{}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &AuthClient{httpClient: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeBaseURL(raw string) string {
|
||||||
|
value := strings.TrimSpace(raw)
|
||||||
|
if value == "" {
|
||||||
|
return DefaultBaseURL
|
||||||
|
}
|
||||||
|
if !strings.Contains(value, "://") {
|
||||||
|
value = "https://" + value
|
||||||
|
}
|
||||||
|
value = strings.TrimRight(value, "/")
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
||||||
|
if token == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
||||||
|
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
||||||
|
}
|
||||||
|
if token.ExpiresIn > 0 {
|
||||||
|
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||||
|
}
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||||
|
verifierBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(verifierBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
||||||
|
}
|
||||||
|
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||||
|
sum := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||||
|
return &PKCECodes{
|
||||||
|
CodeVerifier: verifier,
|
||||||
|
CodeChallenge: challenge,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthServer(port int) *OAuthServer {
|
||||||
|
return &OAuthServer{
|
||||||
|
port: port,
|
||||||
|
resultChan: make(chan *OAuthResult, 1),
|
||||||
|
errorChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) Start() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.running {
|
||||||
|
return fmt.Errorf("gitlab oauth server already running")
|
||||||
|
}
|
||||||
|
if !s.isPortAvailable() {
|
||||||
|
return fmt.Errorf("port %d is already in use", s.port)
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", s.port),
|
||||||
|
Handler: mux,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
s.running = true
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
s.errorChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if !s.running || s.server == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
s.running = false
|
||||||
|
s.server = nil
|
||||||
|
}()
|
||||||
|
return s.server.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||||
|
select {
|
||||||
|
case result := <-s.resultChan:
|
||||||
|
return result, nil
|
||||||
|
case err := <-s.errorChan:
|
||||||
|
return nil, err
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
query := r.URL.Query()
|
||||||
|
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||||
|
s.sendResult(&OAuthResult{Error: errParam})
|
||||||
|
http.Error(w, errParam, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := strings.TrimSpace(query.Get("code"))
|
||||||
|
state := strings.TrimSpace(query.Get("state"))
|
||||||
|
if code == "" || state == "" {
|
||||||
|
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
||||||
|
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||||
|
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||||
|
select {
|
||||||
|
case s.resultChan <- result:
|
||||||
|
default:
|
||||||
|
log.Debug("gitlab oauth result channel full, dropping callback result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) isPortAvailable() bool {
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_ = listener.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedirectURL builds the local OAuth callback URL served on the given port.
func RedirectURL(port int) string {
	return "http://localhost:" + strconv.Itoa(port) + "/auth/callback"
}
|
||||||
|
|
||||||
|
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
||||||
|
if pkce == nil {
|
||||||
|
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(clientID) == "" {
|
||||||
|
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
||||||
|
}
|
||||||
|
baseURL = NormalizeBaseURL(baseURL)
|
||||||
|
params := url.Values{
|
||||||
|
"client_id": {strings.TrimSpace(clientID)},
|
||||||
|
"response_type": {"code"},
|
||||||
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
|
"scope": {defaultOAuthScope},
|
||||||
|
"state": {strings.TrimSpace(state)},
|
||||||
|
"code_challenge": {pkce.CodeChallenge},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"client_id": {strings.TrimSpace(clientID)},
|
||||||
|
"code": {strings.TrimSpace(code)},
|
||||||
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
|
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
||||||
|
}
|
||||||
|
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||||
|
form.Set("client_secret", secret)
|
||||||
|
}
|
||||||
|
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||||
|
}
|
||||||
|
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
||||||
|
form.Set("client_id", clientID)
|
||||||
|
}
|
||||||
|
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||||
|
form.Set("client_secret", secret)
|
||||||
|
}
|
||||||
|
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
var token TokenResponse
|
||||||
|
if err := json.Unmarshal(body, &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
if err := json.Unmarshal(body, &user); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var pat PersonalAccessTokenSelf
|
||||||
|
if err := json.Unmarshal(body, &pat); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
return &pat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var direct DirectAccessResponse
|
||||||
|
if err := json.Unmarshal(body, &direct); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
if direct.Headers == nil {
|
||||||
|
direct.Headers = make(map[string]string)
|
||||||
|
}
|
||||||
|
return &direct, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]DiscoveredModel, 0, 4)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
appendModel := func(provider, name string) {
|
||||||
|
provider = strings.TrimSpace(provider)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
models = append(models, DiscoveredModel{
|
||||||
|
ModelProvider: provider,
|
||||||
|
ModelName: name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw, ok := metadata["model_details"]; ok {
|
||||||
|
appendDiscoveredModels(raw, appendModel)
|
||||||
|
}
|
||||||
|
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||||
|
|
||||||
|
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||||
|
if raw, ok := metadata[key]; ok {
|
||||||
|
appendDiscoveredModels(raw, appendModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||||
|
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||||
|
if nested, ok := typed["models"]; ok {
|
||||||
|
appendDiscoveredModels(nested, appendModel)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendDiscoveredModels(item, appendModel)
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendModel("", item)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
appendModel("", typed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringValue coerces a loosely typed metadata value into a trimmed string.
// Strings and fmt.Stringer values are trimmed; numeric values are formatted
// in decimal. Unsupported types yield the empty string.
//
// Fix: the previous float64 branch truncated via int64 conversion, so 2.5
// became "2". Integral floats still format as plain integers (preserving the
// old output for whole numbers), but fractional values now keep their
// fractional part via strconv.FormatFloat.
func stringValue(raw any) string {
	switch typed := raw.(type) {
	case string:
		return strings.TrimSpace(typed)
	case fmt.Stringer:
		return strings.TrimSpace(typed.String())
	case json.Number:
		return typed.String()
	case int:
		return strconv.Itoa(typed)
	case int64:
		return strconv.FormatInt(typed, 10)
	case float64:
		// Whole values keep the legacy integer rendering (e.g. 3.0 -> "3").
		if i := int64(typed); float64(i) == typed {
			return strconv.FormatInt(i, 10)
		}
		return strconv.FormatFloat(typed, 'g', -1, 64)
	default:
		return ""
	}
}
|
||||||
138
internal/auth/gitlab/gitlab_test.go
Normal file
138
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package gitlab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAuthClientGenerateAuthURLIncludesPKCE verifies that the generated
// authorization URL targets /oauth/authorize and carries the client ID, the
// default scope, and S256 PKCE parameters.
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
	client := NewAuthClient(nil)
	pkce, err := GeneratePKCECodes()
	if err != nil {
		t.Fatalf("GeneratePKCECodes() error = %v", err)
	}

	rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
	if err != nil {
		t.Fatalf("GenerateAuthURL() error = %v", err)
	}

	parsed, err := url.Parse(rawURL)
	if err != nil {
		t.Fatalf("Parse(authURL) error = %v", err)
	}
	if got := parsed.Path; got != "/oauth/authorize" {
		t.Fatalf("expected /oauth/authorize path, got %q", got)
	}
	query := parsed.Query()
	if got := query.Get("client_id"); got != "client-id" {
		t.Fatalf("expected client_id, got %q", got)
	}
	if got := query.Get("scope"); got != defaultOAuthScope {
		t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
	}
	if got := query.Get("code_challenge_method"); got != "S256" {
		t.Fatalf("expected PKCE method S256, got %q", got)
	}
	// The challenge is derived from a random verifier, so only assert presence.
	if got := query.Get("code_challenge"); got == "" {
		t.Fatal("expected non-empty code_challenge")
	}
}
|
||||||
|
|
||||||
|
// TestAuthClientExchangeCodeForTokens spins up a stub token endpoint and
// checks that the code exchange posts an authorization_code grant with the
// PKCE verifier, then decodes the access and refresh tokens from the reply.
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/oauth/token" {
			t.Fatalf("unexpected path %q", r.URL.Path)
		}
		if err := r.ParseForm(); err != nil {
			t.Fatalf("ParseForm() error = %v", err)
		}
		if got := r.Form.Get("grant_type"); got != "authorization_code" {
			t.Fatalf("expected authorization_code grant, got %q", got)
		}
		if got := r.Form.Get("code_verifier"); got != "verifier-123" {
			t.Fatalf("expected code_verifier, got %q", got)
		}
		// Respond with a minimal GitLab-style token payload.
		_ = json.NewEncoder(w).Encode(map[string]any{
			"access_token":  "oauth-access",
			"refresh_token": "oauth-refresh",
			"token_type":    "Bearer",
			"scope":         "api read_user",
			"created_at":    1710000000,
			"expires_in":    3600,
		})
	}))
	defer srv.Close()

	client := NewAuthClient(nil)
	token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
	if err != nil {
		t.Fatalf("ExchangeCodeForTokens() error = %v", err)
	}
	if token.AccessToken != "oauth-access" {
		t.Fatalf("expected access token, got %q", token.AccessToken)
	}
	if token.RefreshToken != "oauth-refresh" {
		t.Fatalf("expected refresh token, got %q", token.RefreshToken)
	}
}
|
||||||
|
|
||||||
|
// TestExtractDiscoveredModels checks that models from "model_details" and
// "supported_models" are merged, that duplicates (same name in both places)
// collapse to one entry, and that first-seen order is preserved.
func TestExtractDiscoveredModels(t *testing.T) {
	models := ExtractDiscoveredModels(map[string]any{
		"model_details": map[string]any{
			"model_provider": "anthropic",
			"model_name":     "claude-sonnet-4-5",
		},
		"supported_models": []any{
			map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
			// Duplicate of the model_details entry; must be de-duplicated.
			"claude-sonnet-4-5",
		},
	})
	if len(models) != 2 {
		t.Fatalf("expected 2 unique models, got %d", len(models))
	}
	if models[0].ModelName != "claude-sonnet-4-5" {
		t.Fatalf("unexpected first model %q", models[0].ModelName)
	}
	if models[1].ModelName != "gpt-4.1" {
		t.Fatalf("unexpected second model %q", models[1].ModelName)
	}
}
|
||||||
|
|
||||||
|
// TestFetchDirectAccessDecodesModelDetails stubs the direct-access endpoint
// and asserts that the bearer token is sent and that the nested
// model_details object is decoded into the response struct.
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
			t.Fatalf("unexpected path %q", r.URL.Path)
		}
		if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
			t.Fatalf("expected bearer token, got %q", got)
		}
		_ = json.NewEncoder(w).Encode(map[string]any{
			"base_url":   "https://cloud.gitlab.example.com",
			"token":      "gateway-token",
			"expires_at": 1710003600,
			"headers": map[string]string{
				"X-Gitlab-Realm": "saas",
			},
			"model_details": map[string]any{
				"model_provider": "anthropic",
				"model_name":     "claude-sonnet-4-5",
			},
		})
	}))
	defer srv.Close()

	client := NewAuthClient(nil)
	direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
	if err != nil {
		t.Fatalf("FetchDirectAccess() error = %v", err)
	}
	if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
		t.Fatalf("expected model details, got %+v", direct.ModelDetails)
	}
}
|
||||||
@@ -5,8 +5,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// newAuthManager creates a new authentication manager instance with all supported
|
// newAuthManager creates a new authentication manager instance with all supported
|
||||||
// authenticators and a file-based token store. It initializes authenticators for
|
// authenticators and a file-based token store.
|
||||||
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||||
@@ -23,6 +22,9 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewKiroAuthenticator(),
|
sdkAuth.NewKiroAuthenticator(),
|
||||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||||
sdkAuth.NewKiloAuthenticator(),
|
sdkAuth.NewKiloAuthenticator(),
|
||||||
|
sdkAuth.NewGitLabAuthenticator(),
|
||||||
|
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||||
|
sdkAuth.NewCursorAuthenticator(),
|
||||||
)
|
)
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|||||||
43
internal/cmd/codebuddy_login.go
Normal file
43
internal/cmd/codebuddy_login.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
|
||||||
|
// It initiates the OAuth authentication, displays the user code for the user to enter
|
||||||
|
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy and auth directory settings
|
||||||
|
// - options: Login options including browser behavior settings
|
||||||
|
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("CodeBuddy authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("CodeBuddy authentication successful!")
|
||||||
|
}
|
||||||
37
internal/cmd/cursor_login.go
Normal file
37
internal/cmd/cursor_login.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
|
||||||
|
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Cursor authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
log.Infof("Authentication saved to %s", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
log.Infof("Authenticated as %s", record.Label)
|
||||||
|
}
|
||||||
|
log.Info("Cursor authentication successful!")
|
||||||
|
}
|
||||||
69
internal/cmd/gitlab_login.go
Normal file
69
internal/cmd/gitlab_login.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"login_mode": "oauth",
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("GitLab Duo authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"login_mode": "pat",
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("GitLab Duo PAT authentication successful!")
|
||||||
|
}
|
||||||
55
internal/config/claude_header_defaults_test.go
Normal file
55
internal/config/claude_header_defaults_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLoadConfigOptional_ClaudeHeaderDefaults verifies that the
// claude-header-defaults YAML section is decoded into ClaudeHeaderDefaults
// with whitespace trimmed from each value and that the optional
// stabilize-device-profile bool pointer is populated.
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
	dir := t.TempDir()
	configPath := filepath.Join(dir, "config.yaml")
	// Values are padded with spaces on purpose: loading must trim them.
	configYAML := []byte(`
claude-header-defaults:
  user-agent: " claude-cli/2.1.70 (external, cli) "
  package-version: " 0.80.0 "
  runtime-version: " v24.5.0 "
  os: " MacOS "
  arch: " arm64 "
  timeout: " 900 "
  stabilize-device-profile: false
`)
	if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
		t.Fatalf("failed to write config: %v", err)
	}

	cfg, err := LoadConfigOptional(configPath, false)
	if err != nil {
		t.Fatalf("LoadConfigOptional() error = %v", err)
	}

	if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
		t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
	}
	if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
		t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
	}
	if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
		t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
	}
	if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
		t.Fatalf("OS = %q, want %q", got, "MacOS")
	}
	if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
		t.Fatalf("Arch = %q, want %q", got, "arm64")
	}
	if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
		t.Fatalf("Timeout = %q, want %q", got, "900")
	}
	if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
		t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
	}
	if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
		t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
	}
}
|
||||||
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLoadConfigOptional_CodexHeaderDefaults verifies that the
// codex-header-defaults YAML section is decoded into CodexHeaderDefaults
// with leading/trailing whitespace trimmed from each value.
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
	dir := t.TempDir()
	configPath := filepath.Join(dir, "config.yaml")
	// Values are padded with spaces on purpose: loading must trim them.
	configYAML := []byte(`
codex-header-defaults:
  user-agent: " my-codex-client/1.0 "
  beta-features: " feature-a,feature-b "
`)
	if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
		t.Fatalf("failed to write config: %v", err)
	}

	cfg, err := LoadConfigOptional(configPath, false)
	if err != nil {
		t.Fatalf("LoadConfigOptional() error = %v", err)
	}

	if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
		t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
	}
	if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
		t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
	}
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -101,6 +102,10 @@ type Config struct {
|
|||||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||||
|
|
||||||
|
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
|
||||||
|
// These are used only when the client does not send its own headers.
|
||||||
|
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
|
||||||
|
|
||||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||||
|
|
||||||
@@ -141,13 +146,27 @@ type Config struct {
|
|||||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
|
||||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
|
||||||
|
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
|
||||||
|
// profiles are enabled, OS/Arch become the pinned platform baseline, while
|
||||||
|
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
|
||||||
type ClaudeHeaderDefaults struct {
|
type ClaudeHeaderDefaults struct {
|
||||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||||
Timeout string `yaml:"timeout" json:"timeout"`
|
OS string `yaml:"os" json:"os"`
|
||||||
|
Arch string `yaml:"arch" json:"arch"`
|
||||||
|
Timeout string `yaml:"timeout" json:"timeout"`
|
||||||
|
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||||
|
// model requests for OAuth/file-backed auth when the client omits them.
|
||||||
|
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
|
||||||
|
type CodexHeaderDefaults struct {
|
||||||
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
|
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig holds HTTPS server settings.
|
// TLSConfig holds HTTPS server settings.
|
||||||
@@ -176,6 +195,9 @@ type RemoteManagement struct {
|
|||||||
SecretKey string `yaml:"secret-key"`
|
SecretKey string `yaml:"secret-key"`
|
||||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||||
|
// DisableAutoUpdatePanel disables automatic periodic background updates of the management panel asset from GitHub.
|
||||||
|
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
|
||||||
|
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
|
||||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||||
@@ -556,6 +578,10 @@ type OpenAICompatibilityModel struct {
|
|||||||
|
|
||||||
// Alias is the model name alias that clients will use to reference this model.
|
// Alias is the model name alias that clients will use to reference this model.
|
||||||
Alias string `yaml:"alias" json:"alias"`
|
Alias string `yaml:"alias" json:"alias"`
|
||||||
|
|
||||||
|
// Thinking configures the thinking/reasoning capability for this model.
|
||||||
|
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
|
||||||
|
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||||
@@ -673,12 +699,18 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||||
cfg.SanitizeGeminiKeys()
|
cfg.SanitizeGeminiKeys()
|
||||||
|
|
||||||
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
// Sanitize Vertex-compatible API keys.
|
||||||
cfg.SanitizeVertexCompatKeys()
|
cfg.SanitizeVertexCompatKeys()
|
||||||
|
|
||||||
// Sanitize Codex keys: drop entries without base-url
|
// Sanitize Codex keys: drop entries without base-url
|
||||||
cfg.SanitizeCodexKeys()
|
cfg.SanitizeCodexKeys()
|
||||||
|
|
||||||
|
// Sanitize Codex header defaults.
|
||||||
|
cfg.SanitizeCodexHeaderDefaults()
|
||||||
|
|
||||||
|
// Sanitize Claude header defaults.
|
||||||
|
cfg.SanitizeClaudeHeaderDefaults()
|
||||||
|
|
||||||
// Sanitize Claude key headers
|
// Sanitize Claude key headers
|
||||||
cfg.SanitizeClaudeKeys()
|
cfg.SanitizeClaudeKeys()
|
||||||
|
|
||||||
@@ -771,6 +803,30 @@ func payloadRawString(value any) ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
|
||||||
|
// configured Codex header fallback values.
|
||||||
|
func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
|
||||||
|
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
|
||||||
|
// configured Claude fingerprint baseline values.
|
||||||
|
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
|
||||||
|
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
|
||||||
|
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
|
||||||
|
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
|
||||||
|
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
|
||||||
|
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ type VertexCompatKey struct {
|
|||||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
// BaseURL optionally overrides the Vertex-compatible API endpoint.
|
||||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
// When empty, requests fall back to the default Vertex API base URL.
|
||||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||||
|
|
||||||
// ProxyURL optionally overrides the global proxy for this API key.
|
// ProxyURL optionally overrides the global proxy for this API key.
|
||||||
@@ -71,10 +71,6 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
|||||||
}
|
}
|
||||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
if entry.BaseURL == "" {
|
|
||||||
// BaseURL is required for Vertex API key entries
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||||
managementSyncMinInterval = 30 * time.Second
|
managementSyncMinInterval = 30 * time.Second
|
||||||
updateCheckInterval = 3 * time.Hour
|
updateCheckInterval = 3 * time.Hour
|
||||||
|
maxAssetDownloadSize = 50 << 20 // 10 MB safety limit for management asset downloads
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManagementFileName exposes the control panel asset filename.
|
// ManagementFileName exposes the control panel asset filename.
|
||||||
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
|
|||||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cfg.RemoteManagement.DisableAutoUpdatePanel {
|
||||||
|
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configPath, _ := schedulerConfigPath.Load().(string)
|
configPath, _ := schedulerConfigPath.Load().(string)
|
||||||
staticDir := StaticDir(configPath)
|
staticDir := StaticDir(configPath)
|
||||||
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
|||||||
}
|
}
|
||||||
|
|
||||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
|
||||||
|
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||||
return false
|
return false
|
||||||
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
|
|||||||
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("read download body: %w", err)
|
return nil, "", fmt.Errorf("read download body: %w", err)
|
||||||
}
|
}
|
||||||
|
if int64(len(data)) > maxAssetDownloadSize {
|
||||||
|
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
|
||||||
|
}
|
||||||
|
|
||||||
sum := sha256.Sum256(data)
|
sum := sha256.Sum256(data)
|
||||||
return data, hex.EncodeToString(sum[:]), nil
|
return data, hex.EncodeToString(sum[:]), nil
|
||||||
|
|||||||
@@ -30,6 +30,23 @@ type OAuthCallback struct {
|
|||||||
ErrorDescription string
|
ErrorDescription string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
|
||||||
|
// the result. The returned channels are buffered (size 1) so the goroutine can
|
||||||
|
// complete even if the caller abandons the channels.
|
||||||
|
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
|
||||||
|
inputCh := make(chan string, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
input, err := promptFn(message)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inputCh <- input
|
||||||
|
}()
|
||||||
|
return inputCh, errCh
|
||||||
|
}
|
||||||
|
|
||||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||||
// It returns nil when the input is empty.
|
// It returns nil when the input is empty.
|
||||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||||
|
|||||||
@@ -1,12 +1,186 @@
|
|||||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||||
// Static model metadata is stored in model_definitions_static_data.go.
|
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||||
|
type staticModelsJSON struct {
|
||||||
|
Claude []*ModelInfo `json:"claude"`
|
||||||
|
Gemini []*ModelInfo `json:"gemini"`
|
||||||
|
Vertex []*ModelInfo `json:"vertex"`
|
||||||
|
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||||
|
AIStudio []*ModelInfo `json:"aistudio"`
|
||||||
|
CodexFree []*ModelInfo `json:"codex-free"`
|
||||||
|
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||||
|
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||||
|
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||||
|
Qwen []*ModelInfo `json:"qwen"`
|
||||||
|
IFlow []*ModelInfo `json:"iflow"`
|
||||||
|
Kimi []*ModelInfo `json:"kimi"`
|
||||||
|
Antigravity []*ModelInfo `json:"antigravity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaudeModels returns the standard Claude model definitions.
|
||||||
|
func GetClaudeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Claude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiModels returns the standard Gemini model definitions.
|
||||||
|
func GetGeminiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||||
|
func GetGeminiVertexModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Vertex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||||
|
func GetGeminiCLIModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().GeminiCLI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIStudioModels returns model definitions for AI Studio.
|
||||||
|
func GetAIStudioModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().AIStudio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||||
|
func GetCodexFreeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexFree)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||||
|
func GetCodexTeamModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexTeam)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||||
|
func GetCodexPlusModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPlus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||||
|
func GetCodexProModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPro)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQwenModels returns the standard Qwen model definitions.
|
||||||
|
func GetQwenModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Qwen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIFlowModels returns the standard iFlow model definitions.
|
||||||
|
func GetIFlowModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().IFlow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||||
|
func GetKimiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Kimi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModels returns the standard Antigravity model definitions.
|
||||||
|
func GetAntigravityModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Antigravity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
|
||||||
|
// These models are served through the copilot.tencent.com API.
|
||||||
|
func GetCodeBuddyModels() []*ModelInfo {
|
||||||
|
now := int64(1748044800) // 2025-05-24
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "glm-5.0",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.0",
|
||||||
|
Description: "GLM-5.0 via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-4.7",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-4.7",
|
||||||
|
Description: "GLM-4.7 via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "minimax-m2.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "MiniMax M2.5",
|
||||||
|
Description: "MiniMax M2.5 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Kimi K2.5",
|
||||||
|
Description: "Kimi K2.5 via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "deepseek-v3-2-volc",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||||
|
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "hunyuan-2.0-thinking",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Hunyuan 2.0 Thinking",
|
||||||
|
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||||
|
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*ModelInfo, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
out[i] = cloneModelInfo(m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||||
// It returns nil when the channel is unknown.
|
// It returns nil when the channel is unknown.
|
||||||
//
|
//
|
||||||
@@ -20,7 +194,6 @@ import (
|
|||||||
// - qwen
|
// - qwen
|
||||||
// - iflow
|
// - iflow
|
||||||
// - kimi
|
// - kimi
|
||||||
// - kiro
|
|
||||||
// - kilo
|
// - kilo
|
||||||
// - github-copilot
|
// - github-copilot
|
||||||
// - amazonq
|
// - amazonq
|
||||||
@@ -39,7 +212,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "aistudio":
|
case "aistudio":
|
||||||
return GetAIStudioModels()
|
return GetAIStudioModels()
|
||||||
case "codex":
|
case "codex":
|
||||||
return GetOpenAIModels()
|
return GetCodexProModels()
|
||||||
case "qwen":
|
case "qwen":
|
||||||
return GetQwenModels()
|
return GetQwenModels()
|
||||||
case "iflow":
|
case "iflow":
|
||||||
@@ -55,33 +228,28 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "amazonq":
|
case "amazonq":
|
||||||
return GetAmazonQModels()
|
return GetAmazonQModels()
|
||||||
case "antigravity":
|
case "antigravity":
|
||||||
cfg := GetAntigravityModelConfig()
|
return GetAntigravityModels()
|
||||||
if len(cfg) == 0 {
|
case "codebuddy":
|
||||||
return nil
|
return GetCodeBuddyModels()
|
||||||
}
|
case "cursor":
|
||||||
models := make([]*ModelInfo, 0, len(cfg))
|
return GetCursorModels()
|
||||||
for modelID, entry := range cfg {
|
|
||||||
if modelID == "" || entry == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
models = append(models, &ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Object: "model",
|
|
||||||
OwnedBy: "antigravity",
|
|
||||||
Type: "antigravity",
|
|
||||||
Thinking: entry.Thinking,
|
|
||||||
MaxCompletionTokens: entry.MaxCompletionTokens,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
sort.Slice(models, func(i, j int) bool {
|
|
||||||
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
|
|
||||||
})
|
|
||||||
return models
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCursorModels returns the fallback Cursor model definitions.
|
||||||
|
func GetCursorModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
|
||||||
|
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
|
||||||
|
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
|
||||||
|
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||||
// Returns nil if no matching model is found.
|
// Returns nil if no matching model is found.
|
||||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||||
@@ -89,38 +257,33 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data := getModels()
|
||||||
allModels := [][]*ModelInfo{
|
allModels := [][]*ModelInfo{
|
||||||
GetClaudeModels(),
|
data.Claude,
|
||||||
GetGeminiModels(),
|
data.Gemini,
|
||||||
GetGeminiVertexModels(),
|
data.Vertex,
|
||||||
GetGeminiCLIModels(),
|
data.GeminiCLI,
|
||||||
GetAIStudioModels(),
|
data.AIStudio,
|
||||||
GetOpenAIModels(),
|
data.CodexPro,
|
||||||
GetQwenModels(),
|
data.Qwen,
|
||||||
GetIFlowModels(),
|
data.IFlow,
|
||||||
GetKimiModels(),
|
data.Kimi,
|
||||||
|
data.Antigravity,
|
||||||
GetGitHubCopilotModels(),
|
GetGitHubCopilotModels(),
|
||||||
GetKiroModels(),
|
GetKiroModels(),
|
||||||
GetKiloModels(),
|
GetKiloModels(),
|
||||||
GetAmazonQModels(),
|
GetAmazonQModels(),
|
||||||
|
GetCodeBuddyModels(),
|
||||||
|
GetCursorModels(),
|
||||||
}
|
}
|
||||||
for _, models := range allModels {
|
for _, models := range allModels {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if m != nil && m.ID == modelID {
|
if m != nil && m.ID == modelID {
|
||||||
return m
|
return cloneModelInfo(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
|
||||||
return &ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Thinking: cfg.Thinking,
|
|
||||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,6 +464,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
SupportedEndpoints: []string{"/responses"},
|
SupportedEndpoints: []string{"/responses"},
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.4",
|
||||||
|
Description: "OpenAI GPT-5.4 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-haiku-4.5",
|
ID: "claude-haiku-4.5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -73,16 +73,16 @@ type availableModelsCacheEntry struct {
|
|||||||
// Values are interpreted in provider-native token units.
|
// Values are interpreted in provider-native token units.
|
||||||
type ThinkingSupport struct {
|
type ThinkingSupport struct {
|
||||||
// Min is the minimum allowed thinking budget (inclusive).
|
// Min is the minimum allowed thinking budget (inclusive).
|
||||||
Min int `json:"min,omitempty"`
|
Min int `json:"min,omitempty" yaml:"min,omitempty"`
|
||||||
// Max is the maximum allowed thinking budget (inclusive).
|
// Max is the maximum allowed thinking budget (inclusive).
|
||||||
Max int `json:"max,omitempty"`
|
Max int `json:"max,omitempty" yaml:"max,omitempty"`
|
||||||
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
|
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
|
||||||
ZeroAllowed bool `json:"zero_allowed,omitempty"`
|
ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
|
||||||
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
|
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
|
||||||
DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
|
DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
|
||||||
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
|
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
|
||||||
// When set, the model uses level-based reasoning instead of token budgets.
|
// When set, the model uses level-based reasoning instead of token budgets.
|
||||||
Levels []string `json:"levels,omitempty"`
|
Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelRegistration tracks a model's availability
|
// ModelRegistration tracks a model's availability
|
||||||
@@ -189,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const defaultModelRegistryHookTimeout = 5 * time.Second
|
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||||
|
const modelQuotaExceededWindow = 5 * time.Minute
|
||||||
|
|
||||||
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||||
hook := r.hook
|
hook := r.hook
|
||||||
@@ -390,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
reg.LastUpdated = now
|
reg.LastUpdated = now
|
||||||
|
// Re-registering an existing client/model binding starts a fresh registry
|
||||||
|
// snapshot for that binding. Cooldown and suspension are transient
|
||||||
|
// scheduling state and must not survive this reconciliation step.
|
||||||
if reg.QuotaExceededClients != nil {
|
if reg.QuotaExceededClients != nil {
|
||||||
delete(reg.QuotaExceededClients, clientID)
|
delete(reg.QuotaExceededClients, clientID)
|
||||||
}
|
}
|
||||||
@@ -783,7 +787,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
|
|
||||||
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||||
models := make([]map[string]any, 0, len(r.models))
|
models := make([]map[string]any, 0, len(r.models))
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
var expiresAt time.Time
|
var expiresAt time.Time
|
||||||
|
|
||||||
for _, registration := range r.models {
|
for _, registration := range r.models {
|
||||||
@@ -794,7 +797,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.
|
|||||||
if quotaTime == nil {
|
if quotaTime == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
recoveryAt := quotaTime.Add(quotaExpiredDuration)
|
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
|
||||||
if now.Before(recoveryAt) {
|
if now.Before(recoveryAt) {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||||
@@ -929,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
result := make([]*ModelInfo, 0, len(providerModels))
|
result := make([]*ModelInfo, 0, len(providerModels))
|
||||||
|
|
||||||
@@ -951,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1005,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
|||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
|
|
||||||
// Count clients that have exceeded quota but haven't recovered yet
|
// Count clients that have exceeded quota but haven't recovered yet
|
||||||
expiredClients := 0
|
expiredClients := 0
|
||||||
for _, quotaTime := range registration.QuotaExceededClients {
|
for _, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1236,12 +1237,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
invalidated := false
|
invalidated := false
|
||||||
|
|
||||||
for modelID, registration := range r.models {
|
for modelID, registration := range r.models {
|
||||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
invalidated = true
|
invalidated = true
|
||||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||||
|
|||||||
372
internal/registry/model_updater.go
Normal file
372
internal/registry/model_updater.go
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
modelsFetchTimeout = 30 * time.Second
|
||||||
|
modelsRefreshInterval = 3 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var modelsURLs = []string{
|
||||||
|
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
|
||||||
|
"https://models.router-for.me/models.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:embed models/models.json
|
||||||
|
var embeddedModelsJSON []byte
|
||||||
|
|
||||||
|
type modelStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
data *staticModelsJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsCatalogStore = &modelStore{}
|
||||||
|
|
||||||
|
var updaterOnce sync.Once
|
||||||
|
|
||||||
|
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
|
||||||
|
// changedProviders contains the provider names whose model definitions changed.
|
||||||
|
type ModelRefreshCallback func(changedProviders []string)
|
||||||
|
|
||||||
|
var (
|
||||||
|
refreshCallbackMu sync.Mutex
|
||||||
|
refreshCallback ModelRefreshCallback
|
||||||
|
pendingRefreshChanges []string
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetModelRefreshCallback registers a callback that is invoked when startup or
|
||||||
|
// periodic model refresh detects changes. Only one callback is supported;
|
||||||
|
// subsequent calls replace the previous callback.
|
||||||
|
func SetModelRefreshCallback(cb ModelRefreshCallback) {
|
||||||
|
refreshCallbackMu.Lock()
|
||||||
|
refreshCallback = cb
|
||||||
|
var pending []string
|
||||||
|
if cb != nil && len(pendingRefreshChanges) > 0 {
|
||||||
|
pending = append([]string(nil), pendingRefreshChanges...)
|
||||||
|
pendingRefreshChanges = nil
|
||||||
|
}
|
||||||
|
refreshCallbackMu.Unlock()
|
||||||
|
|
||||||
|
if cb != nil && len(pending) > 0 {
|
||||||
|
cb(pending)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Load embedded data as fallback on startup.
|
||||||
|
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||||
|
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartModelsUpdater starts a background updater that fetches models
|
||||||
|
// immediately on startup and then refreshes the model catalog every 3 hours.
|
||||||
|
// Safe to call multiple times; only one updater will run.
|
||||||
|
func StartModelsUpdater(ctx context.Context) {
|
||||||
|
updaterOnce.Do(func() {
|
||||||
|
go runModelsUpdater(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runModelsUpdater(ctx context.Context) {
|
||||||
|
tryStartupRefresh(ctx)
|
||||||
|
periodicRefresh(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func periodicRefresh(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(modelsRefreshInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
tryPeriodicRefresh(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryPeriodicRefresh fetches models from remote, compares with the current
|
||||||
|
// catalog, and notifies the registered callback if any provider changed.
|
||||||
|
func tryPeriodicRefresh(ctx context.Context) {
|
||||||
|
tryRefreshModels(ctx, "periodic model refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryStartupRefresh fetches models from remote in the background during
|
||||||
|
// process startup. It uses the same change detection as periodic refresh so
|
||||||
|
// existing auth registrations can be updated after the callback is registered.
|
||||||
|
func tryStartupRefresh(ctx context.Context) {
|
||||||
|
tryRefreshModels(ctx, "startup model refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
func tryRefreshModels(ctx context.Context, label string) {
|
||||||
|
oldData := getModels()
|
||||||
|
|
||||||
|
parsed, url := fetchModelsFromRemote(ctx)
|
||||||
|
if parsed == nil {
|
||||||
|
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect changes before updating store.
|
||||||
|
changed := detectChangedProviders(oldData, parsed)
|
||||||
|
|
||||||
|
// Update store with new data regardless.
|
||||||
|
modelsCatalogStore.mu.Lock()
|
||||||
|
modelsCatalogStore.data = parsed
|
||||||
|
modelsCatalogStore.mu.Unlock()
|
||||||
|
|
||||||
|
if len(changed) == 0 {
|
||||||
|
log.Infof("%s completed from %s, no changes detected", label, url)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
|
||||||
|
notifyModelRefresh(changed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
|
||||||
|
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
|
||||||
|
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
|
||||||
|
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||||
|
for _, url := range modelsURLs {
|
||||||
|
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch request creation failed for %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.Body.Close()
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("models fetch read error from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed staticModelsJSON
|
||||||
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||||
|
log.Warnf("models parse failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := validateModelsCatalog(&parsed); err != nil {
|
||||||
|
log.Warnf("models validate failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return &parsed, url
|
||||||
|
}
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectChangedProviders compares two model catalogs and returns provider names
|
||||||
|
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
|
||||||
|
// under a single "codex" provider.
|
||||||
|
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||||
|
if oldData == nil || newData == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type section struct {
|
||||||
|
provider string
|
||||||
|
oldList []*ModelInfo
|
||||||
|
newList []*ModelInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
sections := []section{
|
||||||
|
{"claude", oldData.Claude, newData.Claude},
|
||||||
|
{"gemini", oldData.Gemini, newData.Gemini},
|
||||||
|
{"vertex", oldData.Vertex, newData.Vertex},
|
||||||
|
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
|
||||||
|
{"aistudio", oldData.AIStudio, newData.AIStudio},
|
||||||
|
{"codex", oldData.CodexFree, newData.CodexFree},
|
||||||
|
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
||||||
|
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
||||||
|
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||||
|
{"qwen", oldData.Qwen, newData.Qwen},
|
||||||
|
{"iflow", oldData.IFlow, newData.IFlow},
|
||||||
|
{"kimi", oldData.Kimi, newData.Kimi},
|
||||||
|
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool, len(sections))
|
||||||
|
var changed []string
|
||||||
|
for _, s := range sections {
|
||||||
|
if seen[s.provider] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if modelSectionChanged(s.oldList, s.newList) {
|
||||||
|
changed = append(changed, s.provider)
|
||||||
|
seen[s.provider] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelSectionChanged reports whether two model slices differ.
|
||||||
|
func modelSectionChanged(a, b []*ModelInfo) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
aj, err1 := json.Marshal(a)
|
||||||
|
bj, err2 := json.Marshal(b)
|
||||||
|
if err1 != nil || err2 != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return string(aj) != string(bj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func notifyModelRefresh(changedProviders []string) {
|
||||||
|
if len(changedProviders) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshCallbackMu.Lock()
|
||||||
|
cb := refreshCallback
|
||||||
|
if cb == nil {
|
||||||
|
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
|
||||||
|
refreshCallbackMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
refreshCallbackMu.Unlock()
|
||||||
|
cb(changedProviders)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeProviderNames(existing, incoming []string) []string {
|
||||||
|
if len(incoming) == 0 {
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
seen := make(map[string]struct{}, len(existing)+len(incoming))
|
||||||
|
merged := make([]string, 0, len(existing)+len(incoming))
|
||||||
|
for _, provider := range existing {
|
||||||
|
name := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[name] = struct{}{}
|
||||||
|
merged = append(merged, name)
|
||||||
|
}
|
||||||
|
for _, provider := range incoming {
|
||||||
|
name := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[name] = struct{}{}
|
||||||
|
merged = append(merged, name)
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModelsFromBytes(data []byte, source string) error {
|
||||||
|
var parsed staticModelsJSON
|
||||||
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||||
|
return fmt.Errorf("%s: decode models catalog: %w", source, err)
|
||||||
|
}
|
||||||
|
if err := validateModelsCatalog(&parsed); err != nil {
|
||||||
|
return fmt.Errorf("%s: validate models catalog: %w", source, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsCatalogStore.mu.Lock()
|
||||||
|
modelsCatalogStore.data = &parsed
|
||||||
|
modelsCatalogStore.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModels() *staticModelsJSON {
|
||||||
|
modelsCatalogStore.mu.RLock()
|
||||||
|
defer modelsCatalogStore.mu.RUnlock()
|
||||||
|
return modelsCatalogStore.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateModelsCatalog(data *staticModelsJSON) error {
|
||||||
|
if data == nil {
|
||||||
|
return fmt.Errorf("catalog is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredSections := []struct {
|
||||||
|
name string
|
||||||
|
models []*ModelInfo
|
||||||
|
}{
|
||||||
|
{name: "claude", models: data.Claude},
|
||||||
|
{name: "gemini", models: data.Gemini},
|
||||||
|
{name: "vertex", models: data.Vertex},
|
||||||
|
{name: "gemini-cli", models: data.GeminiCLI},
|
||||||
|
{name: "aistudio", models: data.AIStudio},
|
||||||
|
{name: "codex-free", models: data.CodexFree},
|
||||||
|
{name: "codex-team", models: data.CodexTeam},
|
||||||
|
{name: "codex-plus", models: data.CodexPlus},
|
||||||
|
{name: "codex-pro", models: data.CodexPro},
|
||||||
|
{name: "qwen", models: data.Qwen},
|
||||||
|
{name: "iflow", models: data.IFlow},
|
||||||
|
{name: "kimi", models: data.Kimi},
|
||||||
|
{name: "antigravity", models: data.Antigravity},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, section := range requiredSections {
|
||||||
|
if err := validateModelSection(section.name, section.models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateModelSection(section string, models []*ModelInfo) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return fmt.Errorf("%s section is empty", section)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, len(models))
|
||||||
|
for i, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
return fmt.Errorf("%s[%d] is null", section, i)
|
||||||
|
}
|
||||||
|
modelID := strings.TrimSpace(model.ID)
|
||||||
|
if modelID == "" {
|
||||||
|
return fmt.Errorf("%s[%d] has empty id", section, i)
|
||||||
|
}
|
||||||
|
if _, exists := seen[modelID]; exists {
|
||||||
|
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
|
||||||
|
}
|
||||||
|
seen[modelID] = struct{}{}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
2683
internal/registry/models/models.json
Normal file
2683
internal/registry/models/models.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -164,7 +164,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -296,7 +296,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||||
}
|
}
|
||||||
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
||||||
return false
|
return false
|
||||||
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||||
}
|
}
|
||||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh refreshes the authentication credentials (no-op for AI Studio).
|
// Refresh refreshes the authentication credentials (no-op for AI Studio).
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
@@ -43,7 +42,6 @@ const (
|
|||||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||||
antigravityGeneratePath = "/v1internal:generateContent"
|
antigravityGeneratePath = "/v1internal:generateContent"
|
||||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
antigravityClientSecret = "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"
|
antigravityClientSecret = "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"
|
||||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
||||||
@@ -55,78 +53,8 @@ const (
|
|||||||
var (
|
var (
|
||||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
randSourceMutex sync.Mutex
|
randSourceMutex sync.Mutex
|
||||||
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
|
|
||||||
// from any antigravity auth. Empty fetches never overwrite this cache.
|
|
||||||
antigravityPrimaryModelsCache struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
models []*registry.ModelInfo
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
|
|
||||||
if len(models) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make([]*registry.ModelInfo, 0, len(models))
|
|
||||||
for _, model := range models {
|
|
||||||
if model == nil || strings.TrimSpace(model.ID) == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out = append(out, cloneAntigravityModelInfo(model))
|
|
||||||
}
|
|
||||||
if len(out) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
|
|
||||||
if model == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
clone := *model
|
|
||||||
if len(model.SupportedGenerationMethods) > 0 {
|
|
||||||
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
|
||||||
}
|
|
||||||
if len(model.SupportedParameters) > 0 {
|
|
||||||
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
|
||||||
}
|
|
||||||
if model.Thinking != nil {
|
|
||||||
thinkingClone := *model.Thinking
|
|
||||||
if len(model.Thinking.Levels) > 0 {
|
|
||||||
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
|
|
||||||
}
|
|
||||||
clone.Thinking = &thinkingClone
|
|
||||||
}
|
|
||||||
return &clone
|
|
||||||
}
|
|
||||||
|
|
||||||
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
|
|
||||||
cloned := cloneAntigravityModels(models)
|
|
||||||
if len(cloned) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
antigravityPrimaryModelsCache.mu.Lock()
|
|
||||||
antigravityPrimaryModelsCache.models = cloned
|
|
||||||
antigravityPrimaryModelsCache.mu.Unlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
|
|
||||||
antigravityPrimaryModelsCache.mu.RLock()
|
|
||||||
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
|
|
||||||
antigravityPrimaryModelsCache.mu.RUnlock()
|
|
||||||
return cloned
|
|
||||||
}
|
|
||||||
|
|
||||||
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
|
|
||||||
models := loadAntigravityPrimaryModels()
|
|
||||||
if len(models) > 0 {
|
|
||||||
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
|
|
||||||
}
|
|
||||||
return models
|
|
||||||
}
|
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
type AntigravityExecutor struct {
|
type AntigravityExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -380,7 +308,7 @@ attemptLoop:
|
|||||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -584,7 +512,7 @@ attemptLoop:
|
|||||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -763,31 +691,42 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
partsJSON, _ := json.Marshal(parts)
|
partsJSON, _ := json.Marshal(parts)
|
||||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
updatedTemplate, _ := sjson.SetRawBytes([]byte(responseTemplate), "candidates.0.content.parts", partsJSON)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
if role != "" {
|
if role != "" {
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.content.role", role)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
}
|
}
|
||||||
if finishReason != "" {
|
if finishReason != "" {
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.finishReason", finishReason)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
}
|
}
|
||||||
if modelVersion != "" {
|
if modelVersion != "" {
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "modelVersion", modelVersion)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
}
|
}
|
||||||
if responseID != "" {
|
if responseID != "" {
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "responseId", responseID)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
}
|
}
|
||||||
if usageRaw != "" {
|
if usageRaw != "" {
|
||||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
updatedTemplate, _ = sjson.SetRawBytes([]byte(responseTemplate), "usageMetadata", []byte(usageRaw))
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.promptTokenCount", 0)
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
responseTemplate = string(updatedTemplate)
|
||||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.candidatesTokenCount", 0)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
|
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.totalTokenCount", 0)
|
||||||
|
responseTemplate = string(updatedTemplate)
|
||||||
}
|
}
|
||||||
|
|
||||||
output := `{"response":{},"traceId":""}`
|
output := `{"response":{},"traceId":""}`
|
||||||
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
updatedOutput, _ := sjson.SetRawBytes([]byte(output), "response", []byte(responseTemplate))
|
||||||
|
output = string(updatedOutput)
|
||||||
if traceID != "" {
|
if traceID != "" {
|
||||||
output, _ = sjson.Set(output, "traceId", traceID)
|
updatedOutput, _ = sjson.SetBytes([]byte(output), "traceId", traceID)
|
||||||
|
output = string(updatedOutput)
|
||||||
}
|
}
|
||||||
return []byte(output)
|
return []byte(output)
|
||||||
}
|
}
|
||||||
@@ -952,12 +891,12 @@ attemptLoop:
|
|||||||
|
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||||
for i := range tail {
|
for i := range tail {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
@@ -1115,7 +1054,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
@@ -1150,168 +1089,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
|
||||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
|
||||||
exec := &AntigravityExecutor{cfg: cfg}
|
|
||||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
|
||||||
if errToken != nil || token == "" {
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
if updatedAuth != nil {
|
|
||||||
auth = updatedAuth
|
|
||||||
}
|
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
|
||||||
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
|
||||||
modelsURL := baseURL + antigravityModelsPath
|
|
||||||
|
|
||||||
var payload []byte
|
|
||||||
if auth != nil && auth.Metadata != nil {
|
|
||||||
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
|
||||||
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(payload) == 0 {
|
|
||||||
payload = []byte(`{}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
|
|
||||||
if errReq != nil {
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
httpReq.Close = true
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
||||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
|
||||||
if host := resolveHost(baseURL); host != "" {
|
|
||||||
httpReq.Host = host
|
|
||||||
}
|
|
||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
|
||||||
if errDo != nil {
|
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
if idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
if errRead != nil {
|
|
||||||
if idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
|
|
||||||
result := gjson.GetBytes(bodyBytes, "models")
|
|
||||||
if !result.Exists() {
|
|
||||||
if idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().Unix()
|
|
||||||
modelConfig := registry.GetAntigravityModelConfig()
|
|
||||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
|
||||||
for originalName, modelData := range result.Map() {
|
|
||||||
modelID := strings.TrimSpace(originalName)
|
|
||||||
if modelID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch modelID {
|
|
||||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
modelCfg := modelConfig[modelID]
|
|
||||||
|
|
||||||
// Extract displayName from upstream response, fallback to modelID
|
|
||||||
displayName := modelData.Get("displayName").String()
|
|
||||||
if displayName == "" {
|
|
||||||
displayName = modelID
|
|
||||||
}
|
|
||||||
|
|
||||||
modelInfo := ®istry.ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Name: modelID,
|
|
||||||
Description: displayName,
|
|
||||||
DisplayName: displayName,
|
|
||||||
Version: modelID,
|
|
||||||
Object: "model",
|
|
||||||
Created: now,
|
|
||||||
OwnedBy: antigravityAuthType,
|
|
||||||
Type: antigravityAuthType,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build input modalities from upstream capability flags.
|
|
||||||
inputModalities := []string{"TEXT"}
|
|
||||||
if modelData.Get("supportsImages").Bool() {
|
|
||||||
inputModalities = append(inputModalities, "IMAGE")
|
|
||||||
}
|
|
||||||
if modelData.Get("supportsVideo").Bool() {
|
|
||||||
inputModalities = append(inputModalities, "VIDEO")
|
|
||||||
}
|
|
||||||
modelInfo.SupportedInputModalities = inputModalities
|
|
||||||
modelInfo.SupportedOutputModalities = []string{"TEXT"}
|
|
||||||
|
|
||||||
// Token limits from upstream.
|
|
||||||
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
|
||||||
modelInfo.InputTokenLimit = int(maxTok)
|
|
||||||
}
|
|
||||||
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
|
||||||
modelInfo.OutputTokenLimit = int(maxOut)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Supported generation methods (Gemini v1beta convention).
|
|
||||||
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
|
|
||||||
|
|
||||||
// Look up Thinking support from static config using upstream model name.
|
|
||||||
if modelCfg != nil {
|
|
||||||
if modelCfg.Thinking != nil {
|
|
||||||
modelInfo.Thinking = modelCfg.Thinking
|
|
||||||
}
|
|
||||||
if modelCfg.MaxCompletionTokens > 0 {
|
|
||||||
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
models = append(models, modelInfo)
|
|
||||||
}
|
|
||||||
if len(models) == 0 {
|
|
||||||
if idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
storeAntigravityPrimaryModels(models)
|
|
||||||
return models
|
|
||||||
}
|
|
||||||
return fallbackAntigravityPrimaryModels()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||||
@@ -1499,19 +1276,20 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
|
|
||||||
// if useAntigravitySchema {
|
// if useAntigravitySchema {
|
||||||
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.role", "user")
|
||||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.0.text", systemInstruction)
|
||||||
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||||
|
|
||||||
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||||
// for _, partResult := range systemInstructionPartsResult.Array() {
|
// for _, partResult := range systemInstructionPartsResult.Array() {
|
||||||
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
// payloadStr, _ = sjson.SetRawBytes([]byte(payloadStr), "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") {
|
if strings.Contains(modelName, "claude") {
|
||||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||||
|
payloadStr = string(updated)
|
||||||
} else {
|
} else {
|
||||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||||
}
|
}
|
||||||
@@ -1733,8 +1511,9 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
template := payload
|
||||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
template, _ = sjson.SetBytes(template, "model", modelName)
|
||||||
|
template, _ = sjson.SetBytes(template, "userAgent", "antigravity")
|
||||||
|
|
||||||
isImageModel := strings.Contains(modelName, "image")
|
isImageModel := strings.Contains(modelName, "image")
|
||||||
|
|
||||||
@@ -1744,28 +1523,28 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
|||||||
} else {
|
} else {
|
||||||
reqType = "agent"
|
reqType = "agent"
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "requestType", reqType)
|
template, _ = sjson.SetBytes(template, "requestType", reqType)
|
||||||
|
|
||||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
template, _ = sjson.Set(template, "project", projectID)
|
template, _ = sjson.SetBytes(template, "project", projectID)
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
template, _ = sjson.SetBytes(template, "project", generateProjectID())
|
||||||
}
|
}
|
||||||
|
|
||||||
if isImageModel {
|
if isImageModel {
|
||||||
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
|
template, _ = sjson.SetBytes(template, "requestId", generateImageGenRequestID())
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
template, _ = sjson.SetBytes(template, "requestId", generateRequestID())
|
||||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
template, _ = sjson.SetBytes(template, "request.sessionId", generateStableSessionID(payload))
|
||||||
}
|
}
|
||||||
|
|
||||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
template, _ = sjson.DeleteBytes(template, "request.safetySettings")
|
||||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
if toolConfig := gjson.GetBytes(template, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(template, "request.toolConfig").Exists() {
|
||||||
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
|
template, _ = sjson.SetRawBytes(template, "request.toolConfig", []byte(toolConfig.Raw))
|
||||||
template, _ = sjson.Delete(template, "toolConfig")
|
template, _ = sjson.DeleteBytes(template, "toolConfig")
|
||||||
}
|
}
|
||||||
return []byte(template)
|
return template
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateRequestID() string {
|
func generateRequestID() string {
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
|
|||||||
package executor
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func resetAntigravityPrimaryModelsCacheForTest() {
|
|
||||||
antigravityPrimaryModelsCache.mu.Lock()
|
|
||||||
antigravityPrimaryModelsCache.models = nil
|
|
||||||
antigravityPrimaryModelsCache.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
|
|
||||||
resetAntigravityPrimaryModelsCacheForTest()
|
|
||||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
|
||||||
|
|
||||||
seed := []*registry.ModelInfo{
|
|
||||||
{ID: "claude-sonnet-4-5"},
|
|
||||||
{ID: "gemini-2.5-pro"},
|
|
||||||
}
|
|
||||||
if updated := storeAntigravityPrimaryModels(seed); !updated {
|
|
||||||
t.Fatal("expected non-empty model list to update primary cache")
|
|
||||||
}
|
|
||||||
|
|
||||||
if updated := storeAntigravityPrimaryModels(nil); updated {
|
|
||||||
t.Fatal("expected nil model list not to overwrite primary cache")
|
|
||||||
}
|
|
||||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
|
|
||||||
t.Fatal("expected empty model list not to overwrite primary cache")
|
|
||||||
}
|
|
||||||
|
|
||||||
got := loadAntigravityPrimaryModels()
|
|
||||||
if len(got) != 2 {
|
|
||||||
t.Fatalf("expected cached model count 2, got %d", len(got))
|
|
||||||
}
|
|
||||||
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
|
|
||||||
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
|
|
||||||
resetAntigravityPrimaryModelsCacheForTest()
|
|
||||||
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
|
|
||||||
|
|
||||||
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
|
|
||||||
ID: "gpt-5",
|
|
||||||
DisplayName: "GPT-5",
|
|
||||||
SupportedGenerationMethods: []string{"generateContent"},
|
|
||||||
SupportedParameters: []string{"temperature"},
|
|
||||||
Thinking: ®istry.ThinkingSupport{
|
|
||||||
Levels: []string{"high"},
|
|
||||||
},
|
|
||||||
}}); !updated {
|
|
||||||
t.Fatal("expected model cache update")
|
|
||||||
}
|
|
||||||
|
|
||||||
got := loadAntigravityPrimaryModels()
|
|
||||||
if len(got) != 1 {
|
|
||||||
t.Fatalf("expected one cached model, got %d", len(got))
|
|
||||||
}
|
|
||||||
got[0].ID = "mutated-id"
|
|
||||||
if len(got[0].SupportedGenerationMethods) > 0 {
|
|
||||||
got[0].SupportedGenerationMethods[0] = "mutated-method"
|
|
||||||
}
|
|
||||||
if len(got[0].SupportedParameters) > 0 {
|
|
||||||
got[0].SupportedParameters[0] = "mutated-parameter"
|
|
||||||
}
|
|
||||||
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
|
|
||||||
got[0].Thinking.Levels[0] = "mutated-level"
|
|
||||||
}
|
|
||||||
|
|
||||||
again := loadAntigravityPrimaryModels()
|
|
||||||
if len(again) != 1 {
|
|
||||||
t.Fatalf("expected one cached model after mutation, got %d", len(again))
|
|
||||||
}
|
|
||||||
if again[0].ID != "gpt-5" {
|
|
||||||
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
|
|
||||||
}
|
|
||||||
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
|
|
||||||
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
|
|
||||||
}
|
|
||||||
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
|
|
||||||
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
|
|
||||||
}
|
|
||||||
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
|
|
||||||
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
383
internal/runtime/executor/claude_device_profile.go
Normal file
383
internal/runtime/executor/claude_device_profile.go
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultClaudeFingerprintUserAgent = "claude-cli/2.1.63 (external, cli)"
|
||||||
|
defaultClaudeFingerprintPackageVersion = "0.74.0"
|
||||||
|
defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
|
||||||
|
defaultClaudeFingerprintOS = "MacOS"
|
||||||
|
defaultClaudeFingerprintArch = "arm64"
|
||||||
|
claudeDeviceProfileTTL = 7 * 24 * time.Hour
|
||||||
|
claudeDeviceProfileCleanupPeriod = time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)
|
||||||
|
|
||||||
|
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||||
|
claudeDeviceProfileCacheMu sync.RWMutex
|
||||||
|
claudeDeviceProfileCacheCleanupOnce sync.Once
|
||||||
|
|
||||||
|
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeCLIVersion struct {
|
||||||
|
major int
|
||||||
|
minor int
|
||||||
|
patch int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
|
||||||
|
switch {
|
||||||
|
case v.major != other.major:
|
||||||
|
if v.major > other.major {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
case v.minor != other.minor:
|
||||||
|
if v.minor > other.minor {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
case v.patch != other.patch:
|
||||||
|
if v.patch > other.patch {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type claudeDeviceProfile struct {
|
||||||
|
UserAgent string
|
||||||
|
PackageVersion string
|
||||||
|
RuntimeVersion string
|
||||||
|
OS string
|
||||||
|
Arch string
|
||||||
|
Version claudeCLIVersion
|
||||||
|
HasVersion bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type claudeDeviceProfileCacheEntry struct {
|
||||||
|
profile claudeDeviceProfile
|
||||||
|
expire time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||||
|
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
||||||
|
hdrDefault := func(cfgVal, fallback string) string {
|
||||||
|
if strings.TrimSpace(cfgVal) != "" {
|
||||||
|
return strings.TrimSpace(cfgVal)
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
var hd config.ClaudeHeaderDefaults
|
||||||
|
if cfg != nil {
|
||||||
|
hd = cfg.ClaudeHeaderDefaults
|
||||||
|
}
|
||||||
|
|
||||||
|
profile := claudeDeviceProfile{
|
||||||
|
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
||||||
|
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
||||||
|
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
||||||
|
OS: hdrDefault(hd.OS, defaultClaudeFingerprintOS),
|
||||||
|
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
||||||
|
}
|
||||||
|
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||||
|
profile.Version = version
|
||||||
|
profile.HasVersion = true
|
||||||
|
}
|
||||||
|
return profile
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||||
|
func mapStainlessOS() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return "MacOS"
|
||||||
|
case "windows":
|
||||||
|
return "Windows"
|
||||||
|
case "linux":
|
||||||
|
return "Linux"
|
||||||
|
case "freebsd":
|
||||||
|
return "FreeBSD"
|
||||||
|
default:
|
||||||
|
return "Other::" + runtime.GOOS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||||
|
func mapStainlessArch() string {
|
||||||
|
switch runtime.GOARCH {
|
||||||
|
case "amd64":
|
||||||
|
return "x64"
|
||||||
|
case "arm64":
|
||||||
|
return "arm64"
|
||||||
|
case "386":
|
||||||
|
return "x86"
|
||||||
|
default:
|
||||||
|
return "other::" + runtime.GOARCH
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
|
||||||
|
matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
|
||||||
|
if len(matches) != 4 {
|
||||||
|
return claudeCLIVersion{}, false
|
||||||
|
}
|
||||||
|
major, err := strconv.Atoi(matches[1])
|
||||||
|
if err != nil {
|
||||||
|
return claudeCLIVersion{}, false
|
||||||
|
}
|
||||||
|
minor, err := strconv.Atoi(matches[2])
|
||||||
|
if err != nil {
|
||||||
|
return claudeCLIVersion{}, false
|
||||||
|
}
|
||||||
|
patch, err := strconv.Atoi(matches[3])
|
||||||
|
if err != nil {
|
||||||
|
return claudeCLIVersion{}, false
|
||||||
|
}
|
||||||
|
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
|
||||||
|
if candidate.UserAgent == "" || !candidate.HasVersion {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if current.UserAgent == "" || !current.HasVersion {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return candidate.Version.Compare(current.Version) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||||
|
profile.OS = baseline.OS
|
||||||
|
profile.Arch = baseline.Arch
|
||||||
|
return profile
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
||||||
|
// baseline platform and enforces the baseline software fingerprint as a floor.
|
||||||
|
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
||||||
|
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
||||||
|
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||||
|
profile.UserAgent = baseline.UserAgent
|
||||||
|
profile.PackageVersion = baseline.PackageVersion
|
||||||
|
profile.RuntimeVersion = baseline.RuntimeVersion
|
||||||
|
profile.Version = baseline.Version
|
||||||
|
profile.HasVersion = baseline.HasVersion
|
||||||
|
}
|
||||||
|
return profile
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
|
||||||
|
if headers == nil {
|
||||||
|
return claudeDeviceProfile{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
||||||
|
version, ok := parseClaudeCLIVersion(userAgent)
|
||||||
|
if !ok {
|
||||||
|
return claudeDeviceProfile{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
baseline := defaultClaudeDeviceProfile(cfg)
|
||||||
|
profile := claudeDeviceProfile{
|
||||||
|
UserAgent: userAgent,
|
||||||
|
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
||||||
|
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
||||||
|
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
||||||
|
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
||||||
|
Version: version,
|
||||||
|
HasVersion: true,
|
||||||
|
}
|
||||||
|
return profile, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
|
||||||
|
if headers == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(headers.Get(name)); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||||
|
switch {
|
||||||
|
case auth != nil && strings.TrimSpace(auth.ID) != "":
|
||||||
|
return "auth:" + strings.TrimSpace(auth.ID)
|
||||||
|
case strings.TrimSpace(apiKey) != "":
|
||||||
|
return "api_key:" + strings.TrimSpace(apiKey)
|
||||||
|
default:
|
||||||
|
return "global"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
|
||||||
|
sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func startClaudeDeviceProfileCacheCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
purgeExpiredClaudeDeviceProfiles()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func purgeExpiredClaudeDeviceProfiles() {
|
||||||
|
now := time.Now()
|
||||||
|
claudeDeviceProfileCacheMu.Lock()
|
||||||
|
for key, entry := range claudeDeviceProfileCache {
|
||||||
|
if !entry.expire.After(now) {
|
||||||
|
delete(claudeDeviceProfileCache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
|
||||||
|
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
||||||
|
|
||||||
|
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
||||||
|
now := time.Now()
|
||||||
|
baseline := defaultClaudeDeviceProfile(cfg)
|
||||||
|
candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
|
||||||
|
if hasCandidate {
|
||||||
|
candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
|
||||||
|
}
|
||||||
|
if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
|
||||||
|
hasCandidate = false
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeDeviceProfileCacheMu.RLock()
|
||||||
|
entry, hasCached := claudeDeviceProfileCache[cacheKey]
|
||||||
|
cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||||
|
claudeDeviceProfileCacheMu.RUnlock()
|
||||||
|
|
||||||
|
if hasCandidate {
|
||||||
|
if claudeDeviceProfileBeforeCandidateStore != nil {
|
||||||
|
claudeDeviceProfileBeforeCandidateStore(candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeDeviceProfileCacheMu.Lock()
|
||||||
|
entry, hasCached = claudeDeviceProfileCache[cacheKey]
|
||||||
|
cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
|
||||||
|
if cachedValid {
|
||||||
|
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||||
|
}
|
||||||
|
if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
|
||||||
|
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||||
|
claudeDeviceProfileCache[cacheKey] = entry
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
return entry.profile
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
|
||||||
|
profile: candidate,
|
||||||
|
expire: now.Add(claudeDeviceProfileTTL),
|
||||||
|
}
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
if cachedValid {
|
||||||
|
claudeDeviceProfileCacheMu.Lock()
|
||||||
|
entry = claudeDeviceProfileCache[cacheKey]
|
||||||
|
if entry.expire.After(now) && entry.profile.UserAgent != "" {
|
||||||
|
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
|
||||||
|
entry.expire = now.Add(claudeDeviceProfileTTL)
|
||||||
|
claudeDeviceProfileCache[cacheKey] = entry
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
return entry.profile
|
||||||
|
}
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return baseline
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, headerName := range []string{
|
||||||
|
"User-Agent",
|
||||||
|
"X-Stainless-Package-Version",
|
||||||
|
"X-Stainless-Runtime-Version",
|
||||||
|
"X-Stainless-Os",
|
||||||
|
"X-Stainless-Arch",
|
||||||
|
} {
|
||||||
|
r.Header.Del(headerName)
|
||||||
|
}
|
||||||
|
r.Header.Set("User-Agent", profile.UserAgent)
|
||||||
|
r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
|
||||||
|
r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||||
|
r.Header.Set("X-Stainless-Os", profile.OS)
|
||||||
|
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile := defaultClaudeDeviceProfile(cfg)
|
||||||
|
miscEnsure := func(name, fallback string) {
|
||||||
|
if strings.TrimSpace(r.Header.Get(name)) != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(ginHeaders.Get(name)) != "" {
|
||||||
|
r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Header.Set(name, fallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
|
||||||
|
miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
|
||||||
|
miscEnsure("X-Stainless-Os", mapStainlessOS())
|
||||||
|
miscEnsure("X-Stainless-Arch", mapStainlessArch())
|
||||||
|
|
||||||
|
// Legacy mode preserves per-auth custom header overrides. By the time we get
|
||||||
|
// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
|
||||||
|
if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientUA := ""
|
||||||
|
if ginHeaders != nil {
|
||||||
|
clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
|
||||||
|
}
|
||||||
|
if isClaudeCodeClient(clientUA) {
|
||||||
|
r.Header.Set("User-Agent", clientUA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Header.Set("User-Agent", profile.UserAgent)
|
||||||
|
}
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -255,7 +254,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
data,
|
data,
|
||||||
¶m,
|
¶m,
|
||||||
)
|
)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -443,7 +442,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
¶m,
|
¶m,
|
||||||
)
|
)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
@@ -561,7 +560,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
@@ -767,36 +766,6 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
|
||||||
func mapStainlessOS() string {
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
return "MacOS"
|
|
||||||
case "windows":
|
|
||||||
return "Windows"
|
|
||||||
case "linux":
|
|
||||||
return "Linux"
|
|
||||||
case "freebsd":
|
|
||||||
return "FreeBSD"
|
|
||||||
default:
|
|
||||||
return "Other::" + runtime.GOOS
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
|
||||||
func mapStainlessArch() string {
|
|
||||||
switch runtime.GOARCH {
|
|
||||||
case "amd64":
|
|
||||||
return "x64"
|
|
||||||
case "arm64":
|
|
||||||
return "arm64"
|
|
||||||
case "386":
|
|
||||||
return "x86"
|
|
||||||
default:
|
|
||||||
return "other::" + runtime.GOARCH
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||||
hdrDefault := func(cfgVal, fallback string) string {
|
hdrDefault := func(cfgVal, fallback string) string {
|
||||||
if cfgVal != "" {
|
if cfgVal != "" {
|
||||||
@@ -824,6 +793,11 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
ginHeaders = ginCtx.Request.Header
|
ginHeaders = ginCtx.Request.Header
|
||||||
}
|
}
|
||||||
|
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
|
||||||
|
var deviceProfile claudeDeviceProfile
|
||||||
|
if stabilizeDeviceProfile {
|
||||||
|
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||||
@@ -867,25 +841,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||||
// For User-Agent, only forward the client's header if it's already a Claude Code client.
|
|
||||||
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
|
|
||||||
// to avoid leaking the real client identity during cloaking.
|
|
||||||
clientUA := ""
|
|
||||||
if ginHeaders != nil {
|
|
||||||
clientUA = ginHeaders.Get("User-Agent")
|
|
||||||
}
|
|
||||||
if isClaudeCodeClient(clientUA) {
|
|
||||||
r.Header.Set("User-Agent", clientUA)
|
|
||||||
} else {
|
|
||||||
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
|
||||||
}
|
|
||||||
r.Header.Set("Connection", "keep-alive")
|
r.Header.Set("Connection", "keep-alive")
|
||||||
if stream {
|
if stream {
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
@@ -897,13 +855,19 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
|||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||||
}
|
}
|
||||||
// Keep OS/Arch mapping dynamic (not configurable).
|
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
|
||||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
// to the configured baseline while still allowing newer official
|
||||||
|
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
|
||||||
var attrs map[string]string
|
var attrs map[string]string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
attrs = auth.Attributes
|
attrs = auth.Attributes
|
||||||
}
|
}
|
||||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||||
|
if stabilizeDeviceProfile {
|
||||||
|
applyClaudeDeviceProfileHeaders(r, deviceProfile)
|
||||||
|
} else {
|
||||||
|
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
|
||||||
|
}
|
||||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||||
@@ -1260,7 +1224,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
|||||||
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
||||||
partJSON := part.Raw
|
partJSON := part.Raw
|
||||||
if !part.Get("cache_control").Exists() {
|
if !part.Get("cache_control").Exists() {
|
||||||
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
|
updated, _ := sjson.SetBytes([]byte(partJSON), "cache_control.type", "ephemeral")
|
||||||
|
partJSON = string(updated)
|
||||||
}
|
}
|
||||||
result += "," + partJSON
|
result += "," + partJSON
|
||||||
}
|
}
|
||||||
@@ -1268,7 +1233,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
|||||||
})
|
})
|
||||||
} else if system.Type == gjson.String && system.String() != "" {
|
} else if system.Type == gjson.String && system.String() != "" {
|
||||||
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||||
partJSON, _ = sjson.Set(partJSON, "text", system.String())
|
updated, _ := sjson.SetBytes([]byte(partJSON), "text", system.String())
|
||||||
|
partJSON = string(updated)
|
||||||
result += "," + partJSON
|
result += "," + partJSON
|
||||||
}
|
}
|
||||||
result += "]"
|
result += "]"
|
||||||
|
|||||||
@@ -8,8 +8,11 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -19,6 +22,587 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func resetClaudeDeviceProfileCache() {
|
||||||
|
claudeDeviceProfileCacheMu.Lock()
|
||||||
|
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ginReq := httptest.NewRequest(http.MethodPost, "http://localhost/v1/messages", nil)
|
||||||
|
ginReq.Header = incoming.Clone()
|
||||||
|
ginCtx.Request = ginReq
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||||
|
return req.WithContext(context.WithValue(req.Context(), "gin", ginCtx))
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertClaudeFingerprint(t *testing.T, headers http.Header, userAgent, pkgVersion, runtimeVersion, osName, arch string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if got := headers.Get("User-Agent"); got != userAgent {
|
||||||
|
t.Fatalf("User-Agent = %q, want %q", got, userAgent)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Stainless-Package-Version"); got != pkgVersion {
|
||||||
|
t.Fatalf("X-Stainless-Package-Version = %q, want %q", got, pkgVersion)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Stainless-Runtime-Version"); got != runtimeVersion {
|
||||||
|
t.Fatalf("X-Stainless-Runtime-Version = %q, want %q", got, runtimeVersion)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Stainless-Os"); got != osName {
|
||||||
|
t.Fatalf("X-Stainless-Os = %q, want %q", got, osName)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Stainless-Arch"); got != arch {
|
||||||
|
t.Fatalf("X-Stainless-Arch = %q, want %q", got, arch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||||
|
PackageVersion: "0.80.0",
|
||||||
|
RuntimeVersion: "v24.5.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
Timeout: "900",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-baseline",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-baseline",
|
||||||
|
"header:User-Agent": "evil-client/9.9",
|
||||||
|
"header:X-Stainless-Os": "Linux",
|
||||||
|
"header:X-Stainless-Arch": "x64",
|
||||||
|
"header:X-Stainless-Package-Version": "9.9.9",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
incoming := http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := newClaudeHeaderTestRequest(t, incoming)
|
||||||
|
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
||||||
|
|
||||||
|
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||||
|
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
||||||
|
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_TracksHighestClaudeCLIFingerprint(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||||
|
PackageVersion: "0.70.0",
|
||||||
|
RuntimeVersion: "v22.0.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-upgrade",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-upgrade",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
firstReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(firstReq, auth, "key-upgrade", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"lobe-chat/1.0"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Windows"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(thirdPartyReq, auth, "key-upgrade", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
higherReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||||
|
"X-Stainless-Os": []string{"MacOS"},
|
||||||
|
"X-Stainless-Arch": []string{"arm64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(higherReq, auth, "key-upgrade", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, higherReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.73.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
|
||||||
|
"X-Stainless-Os": []string{"Windows"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(lowerReq, auth, "key-upgrade", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_DoesNotDowngradeConfiguredBaselineOnFirstClaudeClient(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||||
|
PackageVersion: "0.80.0",
|
||||||
|
RuntimeVersion: "v24.5.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-baseline-floor",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-baseline-floor",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
olderClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(olderClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, olderClaudeReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
newerClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.81.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(newerClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, newerClaudeReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_UpgradesCachedSoftwareFingerprintWhenBaselineAdvances(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
oldCfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||||
|
PackageVersion: "0.80.0",
|
||||||
|
RuntimeVersion: "v24.5.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.77 (external, cli)",
|
||||||
|
PackageVersion: "0.87.0",
|
||||||
|
RuntimeVersion: "v24.8.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-baseline-reload",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-baseline-reload",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.81.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(officialReq, auth, "key-baseline-reload", false, nil, oldCfg)
|
||||||
|
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(thirdPartyReq, auth, "key-baseline-reload", false, nil, newCfg)
|
||||||
|
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_LearnsOfficialFingerprintAfterCustomBaselineFallback(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "my-gateway/1.0",
|
||||||
|
PackageVersion: "custom-pkg",
|
||||||
|
RuntimeVersion: "custom-runtime",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-custom-baseline-learning",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-custom-baseline-learning",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(thirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, thirdPartyReq.Header, "my-gateway/1.0", "custom-pkg", "custom-runtime", "MacOS", "arm64")
|
||||||
|
|
||||||
|
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.87.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(officialReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
postLearningThirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(postLearningThirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, postLearningThirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||||
|
PackageVersion: "0.70.0",
|
||||||
|
RuntimeVersion: "v22.0.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-racy-upgrade",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-racy-upgrade",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
lowPaused := make(chan struct{})
|
||||||
|
releaseLow := make(chan struct{})
|
||||||
|
var pauseOnce sync.Once
|
||||||
|
var releaseOnce sync.Once
|
||||||
|
|
||||||
|
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
|
||||||
|
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pauseOnce.Do(func() { close(lowPaused) })
|
||||||
|
<-releaseLow
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
claudeDeviceProfileBeforeCandidateStore = nil
|
||||||
|
releaseOnce.Do(func() { close(releaseLow) })
|
||||||
|
})
|
||||||
|
|
||||||
|
lowResultCh := make(chan claudeDeviceProfile, 1)
|
||||||
|
go func() {
|
||||||
|
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
}, cfg)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-lowPaused:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for lower candidate to pause before storing")
|
||||||
|
}
|
||||||
|
|
||||||
|
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||||
|
"X-Stainless-Os": []string{"MacOS"},
|
||||||
|
"X-Stainless-Arch": []string{"arm64"},
|
||||||
|
}, cfg)
|
||||||
|
releaseOnce.Do(func() { close(releaseLow) })
|
||||||
|
|
||||||
|
select {
|
||||||
|
case lowResult := <-lowResultCh:
|
||||||
|
if lowResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||||
|
t.Fatalf("lowResult.UserAgent = %q, want %q", lowResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||||
|
}
|
||||||
|
if lowResult.PackageVersion != "0.75.0" {
|
||||||
|
t.Fatalf("lowResult.PackageVersion = %q, want %q", lowResult.PackageVersion, "0.75.0")
|
||||||
|
}
|
||||||
|
if lowResult.OS != "MacOS" || lowResult.Arch != "arm64" {
|
||||||
|
t.Fatalf("lowResult platform = %s/%s, want %s/%s", lowResult.OS, lowResult.Arch, "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for lower candidate result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if highResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||||
|
t.Fatalf("highResult.UserAgent = %q, want %q", highResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||||
|
}
|
||||||
|
if highResult.OS != "MacOS" || highResult.Arch != "arm64" {
|
||||||
|
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
|
||||||
|
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
}, cfg)
|
||||||
|
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||||
|
t.Fatalf("cached.UserAgent = %q, want %q", cached.UserAgent, "claude-cli/2.1.63 (external, cli)")
|
||||||
|
}
|
||||||
|
if cached.PackageVersion != "0.75.0" {
|
||||||
|
t.Fatalf("cached.PackageVersion = %q, want %q", cached.PackageVersion, "0.75.0")
|
||||||
|
}
|
||||||
|
if cached.OS != "MacOS" || cached.Arch != "arm64" {
|
||||||
|
t.Fatalf("cached platform = %s/%s, want %s/%s", cached.OS, cached.Arch, "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_ThirdPartyBaselineThenOfficialUpgradeKeepsPinnedPlatform(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
stabilize := true
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.70 (external, cli)",
|
||||||
|
PackageVersion: "0.80.0",
|
||||||
|
RuntimeVersion: "v24.5.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-third-party-then-official",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-third-party-then-official",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(thirdPartyReq, auth, "key-third-party-then-official", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
||||||
|
|
||||||
|
officialReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.87.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(officialReq, auth, "key-third-party-then-official", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_DisableDeviceProfileStabilization(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
|
||||||
|
stabilize := false
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||||
|
PackageVersion: "0.70.0",
|
||||||
|
RuntimeVersion: "v22.0.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-disable-stability",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-disable-stability",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
firstReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(firstReq, auth, "key-disable-stability", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "Linux", "x64")
|
||||||
|
|
||||||
|
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"lobe-chat/1.0"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.10.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
|
||||||
|
"X-Stainless-Os": []string{"Windows"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(thirdPartyReq, auth, "key-disable-stability", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.60 (external, cli)", "0.10.0", "v18.0.0", "Windows", "x64")
|
||||||
|
|
||||||
|
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.73.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
|
||||||
|
"X-Stainless-Os": []string{"Windows"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(lowerReq, auth, "key-disable-stability", false, nil, cfg)
|
||||||
|
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.61 (external, cli)", "0.73.0", "v24.2.0", "Windows", "x64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_LegacyModePreservesConfiguredUserAgentOverrideForClaudeClients(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
|
||||||
|
stabilize := false
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||||
|
PackageVersion: "0.70.0",
|
||||||
|
RuntimeVersion: "v22.0.0",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-legacy-ua-override",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-legacy-ua-override",
|
||||||
|
"header:User-Agent": "config-ua/1.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||||
|
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||||
|
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||||
|
"X-Stainless-Os": []string{"Linux"},
|
||||||
|
"X-Stainless-Arch": []string{"x64"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(req, auth, "key-legacy-ua-override", false, nil, cfg)
|
||||||
|
|
||||||
|
assertClaudeFingerprint(t, req.Header, "config-ua/1.0", "0.74.0", "v24.3.0", "Linux", "x64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
|
||||||
|
stabilize := false
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||||
|
PackageVersion: "0.70.0",
|
||||||
|
RuntimeVersion: "v22.0.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
StabilizeDeviceProfile: &stabilize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-legacy-runtime-os-arch",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-legacy-runtime-os-arch",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
|
||||||
|
|
||||||
|
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
|
||||||
|
resetClaudeDeviceProfileCache()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
|
||||||
|
UserAgent: "claude-cli/2.1.60 (external, cli)",
|
||||||
|
PackageVersion: "0.70.0",
|
||||||
|
RuntimeVersion: "v22.0.0",
|
||||||
|
OS: "MacOS",
|
||||||
|
Arch: "arm64",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-unset-runtime-os-arch",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "key-unset-runtime-os-arch",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := newClaudeHeaderTestRequest(t, http.Header{
|
||||||
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
|
})
|
||||||
|
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
|
||||||
|
|
||||||
|
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
|
||||||
|
if claudeDeviceProfileStabilizationEnabled(nil) {
|
||||||
|
t.Fatal("expected nil config to default to disabled stabilization")
|
||||||
|
}
|
||||||
|
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
||||||
|
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||||
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
||||||
out := applyClaudeToolPrefix(input, "proxy_")
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
@@ -842,8 +1426,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
|
|||||||
executor := NewClaudeExecutor(&config.Config{})
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
"api_key": "key-123",
|
"api_key": "key-123",
|
||||||
"base_url": server.URL,
|
"base_url": server.URL,
|
||||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||||
}}
|
}}
|
||||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|||||||
343
internal/runtime/executor/codebuddy_executor.go
Normal file
343
internal/runtime/executor/codebuddy_executor.go
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// codeBuddyChatPath is the chat-completions endpoint appended to codebuddy.BaseURL.
	codeBuddyChatPath = "/v2/chat/completions"
	// codeBuddyAuthType is this executor's provider identifier and auth type key.
	codeBuddyAuthType = "codebuddy"
)

// CodeBuddyExecutor handles requests to the CodeBuddy API.
type CodeBuddyExecutor struct {
	// cfg carries proxy/logging/payload settings shared by all requests.
	cfg *config.Config
}

// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
	return &CodeBuddyExecutor{cfg: cfg}
}

// Identifier returns the unique identifier for this executor.
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
|
||||||
|
|
||||||
|
// codeBuddyCredentials extracts the access token and domain from auth metadata.
|
||||||
|
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", "", ""
|
||||||
|
}
|
||||||
|
accessToken = metaStringValue(auth.Metadata, "access_token")
|
||||||
|
userID = metaStringValue(auth.Metadata, "user_id")
|
||||||
|
domain = metaStringValue(auth.Metadata, "domain")
|
||||||
|
if domain == "" {
|
||||||
|
domain = codebuddy.DefaultDomain
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareRequest prepares the HTTP request before execution.
|
||||||
|
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||||
|
if accessToken == "" {
|
||||||
|
return fmt.Errorf("codebuddy: missing access token")
|
||||||
|
}
|
||||||
|
e.applyHeaders(req, accessToken, userID, domain)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest executes a raw HTTP request.
|
||||||
|
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute performs a non-streaming request.
//
// Pipeline: strip any thinking suffix from the model name, translate the
// payload from the caller's format into OpenAI chat-completions form, apply
// payload config overrides and thinking directives, POST to the CodeBuddy
// chat endpoint, then translate the response body back to the caller's format.
// Usage accounting and request/response logging happen along the way.
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName

	// trackFailure inspects the named err on return so failed calls are counted.
	reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
	defer reporter.trackFailure(ctx, &err)

	accessToken, userID, domain := codeBuddyCredentials(auth)
	if accessToken == "" {
		return resp, fmt.Errorf("codebuddy: missing access token")
	}

	// CodeBuddy speaks the OpenAI chat format upstream.
	from := opts.SourceFormat
	to := sdktranslator.FromString("openai")

	// Prefer the caller's untouched original request for the "original"
	// translation when one was captured; otherwise use the working payload.
	originalPayloadSource := req.Payload
	if len(opts.OriginalRequest) > 0 {
		originalPayloadSource = opts.OriginalRequest
	}
	originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
	translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
	requestedModel := payloadRequestedModel(opts, req.Model)
	translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)

	// Fold any thinking-mode suffix semantics into the translated payload.
	translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return resp, err
	}

	url := codebuddy.BaseURL + codeBuddyChatPath
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
	if err != nil {
		return resp, err
	}
	e.applyHeaders(httpReq, accessToken, userID, domain)

	// Capture auth identity fields for the upstream request log.
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translated,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("codebuddy executor: close response body error: %v", errClose)
		}
	}()

	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if !isHTTPSuccess(httpResp.StatusCode) {
		// Surface the upstream error body verbatim via statusErr.
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return resp, err
	}

	body, err := io.ReadAll(httpResp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	appendAPIResponseChunk(ctx, e.cfg, body)
	// Publish token usage parsed from the OpenAI-shaped response.
	reporter.publish(ctx, parseOpenAIUsage(body))
	reporter.ensurePublished(ctx)

	// Translate the upstream response back into the caller's source format.
	var param any
	out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
	resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
	return resp, nil
}
|
||||||
|
|
||||||
|
// ExecuteStream performs a streaming request.
//
// It mirrors Execute's translate/POST pipeline but requests SSE from the
// upstream, then pumps each "data:" line through the stream translator on a
// background goroutine, delivering translated chunks over the returned
// StreamResult channel. The goroutine owns the response body and closes both
// the body and the channel when the stream ends.
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName

	// trackFailure inspects the named err on return so failed calls are counted.
	reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
	defer reporter.trackFailure(ctx, &err)

	accessToken, userID, domain := codeBuddyCredentials(auth)
	if accessToken == "" {
		return nil, fmt.Errorf("codebuddy: missing access token")
	}

	// CodeBuddy speaks the OpenAI chat format upstream; stream=true here.
	from := opts.SourceFormat
	to := sdktranslator.FromString("openai")

	originalPayloadSource := req.Payload
	if len(opts.OriginalRequest) > 0 {
		originalPayloadSource = opts.OriginalRequest
	}
	originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
	translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
	requestedModel := payloadRequestedModel(opts, req.Model)
	translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)

	translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return nil, err
	}

	url := codebuddy.BaseURL + codeBuddyChatPath
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
	if err != nil {
		return nil, err
	}
	e.applyHeaders(httpReq, accessToken, userID, domain)
	// Ask the upstream for server-sent events rather than a single JSON body.
	httpReq.Header.Set("Accept", "text/event-stream")
	httpReq.Header.Set("Cache-Control", "no-cache")

	// Capture auth identity fields for the upstream request log.
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translated,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return nil, err
	}

	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if !isHTTPSuccess(httpResp.StatusCode) {
		// Error path: drain and close the body here, since the streaming
		// goroutine (the usual owner) is never started.
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		httpResp.Body.Close()
		log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return nil, err
	}

	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		// The goroutine owns the response body from here on.
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("codebuddy executor: close stream body error: %v", errClose)
			}
		}()

		scanner := bufio.NewScanner(httpResp.Body)
		scanner.Buffer(nil, maxScannerBufferSize)
		var param any
		for scanner.Scan() {
			line := scanner.Bytes()
			appendAPIResponseChunk(ctx, e.cfg, line)
			// Usage records can appear on any line, including non-data ones.
			if detail, ok := parseOpenAIStreamUsage(line); ok {
				reporter.publish(ctx, detail)
			}
			if len(line) == 0 {
				continue
			}
			// Only SSE "data:" lines are translated into downstream chunks.
			if !bytes.HasPrefix(line, []byte("data:")) {
				continue
			}
			// bytes.Clone: Scanner reuses its buffer, so detach before handing off.
			chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
			for i := range chunks {
				out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
			}
		}
		if errScan := scanner.Err(); errScan != nil {
			recordAPIResponseError(ctx, e.cfg, errScan)
			reporter.publishFailure(ctx)
			out <- cliproxyexecutor.StreamChunk{Err: errScan}
		}
		reporter.ensurePublished(ctx)
	}()

	return &cliproxyexecutor.StreamResult{
		Headers: httpResp.Header.Clone(),
		Chunks:  out,
	}, nil
}
|
||||||
|
|
||||||
|
// Refresh exchanges the CodeBuddy refresh token for a new access token.
|
||||||
|
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: missing auth")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
|
||||||
|
if refreshToken == "" {
|
||||||
|
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||||
|
|
||||||
|
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
|
||||||
|
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated := auth.Clone()
|
||||||
|
updated.Metadata["access_token"] = storage.AccessToken
|
||||||
|
if storage.RefreshToken != "" {
|
||||||
|
updated.Metadata["refresh_token"] = storage.RefreshToken
|
||||||
|
}
|
||||||
|
updated.Metadata["expires_in"] = storage.ExpiresIn
|
||||||
|
updated.Metadata["domain"] = storage.Domain
|
||||||
|
if storage.UserID != "" {
|
||||||
|
updated.Metadata["user_id"] = storage.UserID
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
updated.UpdatedAt = now
|
||||||
|
updated.LastRefreshedAt = now
|
||||||
|
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokens is not supported for CodeBuddy.
|
||||||
|
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyHeaders sets required headers for CodeBuddy API requests.
|
||||||
|
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("User-Agent", codebuddy.UserAgent)
|
||||||
|
req.Header.Set("X-User-Id", userID)
|
||||||
|
req.Header.Set("X-Domain", domain)
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("X-IDE-Type", "CLI")
|
||||||
|
req.Header.Set("X-IDE-Name", "CLI")
|
||||||
|
req.Header.Set("X-IDE-Version", "2.63.2")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
}
|
||||||
@@ -28,8 +28,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
codexClientVersion = "0.101.0"
|
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
codexOriginator = "codex_cli_rs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dataTag = []byte("data:")
|
var dataTag = []byte("data:")
|
||||||
@@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyCodexHeaders(httpReq, auth, apiKey, true)
|
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
|
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||||
@@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyCodexHeaders(httpReq, auth, apiKey, false)
|
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -273,7 +273,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyCodexHeaders(httpReq, auth, apiKey, true)
|
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -387,7 +387,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
|
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
@@ -432,7 +432,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
|
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
|
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
|
||||||
@@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
return httpReq, nil
|
return httpReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) {
|
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
|
||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
r.Header.Set("Authorization", "Bearer "+token)
|
r.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
|
||||||
@@ -645,9 +645,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
ginHeaders = ginCtx.Request.Header
|
ginHeaders = ginCtx.Request.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
||||||
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
||||||
|
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||||
|
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
@@ -662,8 +665,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
isAPIKey = true
|
isAPIKey = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||||
|
r.Header.Set("Originator", originator)
|
||||||
|
} else if !isAPIKey {
|
||||||
|
r.Header.Set("Originator", codexOriginator)
|
||||||
|
}
|
||||||
if !isAPIKey {
|
if !isAPIKey {
|
||||||
r.Header.Set("Originator", "codex_cli_rs")
|
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||||
r.Header.Set("Chatgpt-Account-Id", accountID)
|
r.Header.Set("Chatgpt-Account-Id", accountID)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -190,7 +191,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -342,7 +343,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: out}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -385,7 +386,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
}
|
}
|
||||||
|
|
||||||
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
|
||||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -591,7 +592,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
line := encodeCodexWebsocketAsSSE(payload)
|
line := encodeCodexWebsocketAsSSE(payload)
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
|
if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
|
||||||
terminateReason = "context_done"
|
terminateReason = "context_done"
|
||||||
terminateErr = ctx.Err()
|
terminateErr = ctx.Err()
|
||||||
return
|
return
|
||||||
@@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
|||||||
return dialer
|
return dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, errParse := url.Parse(proxyURL)
|
setting, errParse := proxyutil.Parse(proxyURL)
|
||||||
if errParse != nil {
|
if errParse != nil {
|
||||||
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
|
log.Errorf("codex websockets executor: %v", errParse)
|
||||||
return dialer
|
return dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
switch parsedURL.Scheme {
|
switch setting.Mode {
|
||||||
|
case proxyutil.ModeDirect:
|
||||||
|
dialer.Proxy = nil
|
||||||
|
return dialer
|
||||||
|
case proxyutil.ModeProxy:
|
||||||
|
default:
|
||||||
|
return dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
switch setting.URL.Scheme {
|
||||||
case "socks5":
|
case "socks5":
|
||||||
var proxyAuth *proxy.Auth
|
var proxyAuth *proxy.Auth
|
||||||
if parsedURL.User != nil {
|
if setting.URL.User != nil {
|
||||||
username := parsedURL.User.Username()
|
username := setting.URL.User.Username()
|
||||||
password, _ := parsedURL.User.Password()
|
password, _ := setting.URL.User.Password()
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||||
}
|
}
|
||||||
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
||||||
if errSOCKS5 != nil {
|
if errSOCKS5 != nil {
|
||||||
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||||
return dialer
|
return dialer
|
||||||
@@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
|||||||
return socksDialer.Dial(network, addr)
|
return socksDialer.Dial(network, addr)
|
||||||
}
|
}
|
||||||
case "http", "https":
|
case "http", "https":
|
||||||
dialer.Proxy = http.ProxyURL(parsedURL)
|
dialer.Proxy = http.ProxyURL(setting.URL)
|
||||||
default:
|
default:
|
||||||
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
|
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dialer
|
return dialer
|
||||||
@@ -787,7 +797,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
|||||||
return rawJSON, headers
|
return rawJSON, headers
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header {
|
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
|
||||||
if headers == nil {
|
if headers == nil {
|
||||||
headers = http.Header{}
|
headers = http.Header{}
|
||||||
}
|
}
|
||||||
@@ -800,12 +810,14 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
|||||||
ginHeaders = ginCtx.Request.Header
|
ginHeaders = ginCtx.Request.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "")
|
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||||
|
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||||
|
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
|
||||||
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
||||||
|
misc.EnsureHeader(headers, ginHeaders, "Version", "")
|
||||||
|
|
||||||
misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
|
|
||||||
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
||||||
if betaHeader == "" && ginHeaders != nil {
|
if betaHeader == "" && ginHeaders != nil {
|
||||||
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
|
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
|
||||||
@@ -815,7 +827,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
|||||||
}
|
}
|
||||||
headers.Set("OpenAI-Beta", betaHeader)
|
headers.Set("OpenAI-Beta", betaHeader)
|
||||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||||
misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent)
|
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||||
|
|
||||||
isAPIKey := false
|
isAPIKey := false
|
||||||
if auth != nil && auth.Attributes != nil {
|
if auth != nil && auth.Attributes != nil {
|
||||||
@@ -823,8 +835,12 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
|||||||
isAPIKey = true
|
isAPIKey = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||||
|
headers.Set("Originator", originator)
|
||||||
|
} else if !isAPIKey {
|
||||||
|
headers.Set("Originator", codexOriginator)
|
||||||
|
}
|
||||||
if !isAPIKey {
|
if !isAPIKey {
|
||||||
headers.Set("Originator", "codex_cli_rs")
|
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||||
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
|
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
|
||||||
@@ -843,6 +859,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
|||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||||
|
if target == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(target.Get(key)) != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if source != nil {
|
||||||
|
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
||||||
|
target.Set(key, val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := strings.TrimSpace(configValue); val != "" {
|
||||||
|
target.Set(key, val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||||
|
target.Set(key, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||||
|
if target == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(target.Get(key)) != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if val := strings.TrimSpace(configValue); val != "" {
|
||||||
|
target.Set(key, val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if source != nil {
|
||||||
|
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
||||||
|
target.Set(key, val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||||
|
target.Set(key, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type statusErrWithHeaders struct {
|
type statusErrWithHeaders struct {
|
||||||
statusErr
|
statusErr
|
||||||
headers http.Header
|
headers http.Header
|
||||||
|
|||||||
@@ -3,8 +3,13 @@ package executor
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,9 +33,259 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||||
|
|
||||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||||
}
|
}
|
||||||
|
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||||
|
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||||
|
}
|
||||||
|
if got := headers.Get("Version"); got != "" {
|
||||||
|
t.Fatalf("Version = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||||
|
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Codex-Turn-Metadata"); got != "" {
|
||||||
|
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Client-Request-Id"); got != "" {
|
||||||
|
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Metadata: map[string]any{"email": "user@example.com"},
|
||||||
|
}
|
||||||
|
ctx := contextWithGinHeaders(map[string]string{
|
||||||
|
"Originator": "Codex Desktop",
|
||||||
|
"Version": "0.115.0-alpha.27",
|
||||||
|
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||||
|
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||||
|
})
|
||||||
|
|
||||||
|
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
|
||||||
|
|
||||||
|
if got := headers.Get("Originator"); got != "Codex Desktop" {
|
||||||
|
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||||
|
}
|
||||||
|
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
|
||||||
|
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
|
||||||
|
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
|
||||||
|
}
|
||||||
|
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||||
|
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||||
|
UserAgent: "my-codex-client/1.0",
|
||||||
|
BetaFeatures: "feature-a,feature-b",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Metadata: map[string]any{"email": "user@example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
||||||
|
|
||||||
|
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
|
||||||
|
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
|
||||||
|
}
|
||||||
|
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
||||||
|
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
||||||
|
}
|
||||||
|
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||||
|
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||||
|
UserAgent: "config-ua",
|
||||||
|
BetaFeatures: "config-beta",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Metadata: map[string]any{"email": "user@example.com"},
|
||||||
|
}
|
||||||
|
ctx := contextWithGinHeaders(map[string]string{
|
||||||
|
"User-Agent": "client-ua",
|
||||||
|
"X-Codex-Beta-Features": "client-beta",
|
||||||
|
})
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("User-Agent", "existing-ua")
|
||||||
|
headers.Set("X-Codex-Beta-Features", "existing-beta")
|
||||||
|
|
||||||
|
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
||||||
|
|
||||||
|
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
|
||||||
|
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
|
||||||
|
}
|
||||||
|
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
||||||
|
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||||
|
UserAgent: "config-ua",
|
||||||
|
BetaFeatures: "config-beta",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Metadata: map[string]any{"email": "user@example.com"},
|
||||||
|
}
|
||||||
|
ctx := contextWithGinHeaders(map[string]string{
|
||||||
|
"User-Agent": "client-ua",
|
||||||
|
"X-Codex-Beta-Features": "client-beta",
|
||||||
|
})
|
||||||
|
|
||||||
|
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
||||||
|
|
||||||
|
if got := headers.Get("User-Agent"); got != "config-ua" {
|
||||||
|
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||||
|
}
|
||||||
|
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
||||||
|
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||||
|
UserAgent: "config-ua",
|
||||||
|
BetaFeatures: "config-beta",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Attributes: map[string]string{"api_key": "sk-test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
|
||||||
|
|
||||||
|
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||||
|
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||||
|
}
|
||||||
|
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||||
|
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error = %v", err)
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
CodexHeaderDefaults: config.CodexHeaderDefaults{
|
||||||
|
UserAgent: "config-ua",
|
||||||
|
BetaFeatures: "config-beta",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Metadata: map[string]any{"email": "user@example.com"},
|
||||||
|
}
|
||||||
|
req = req.WithContext(contextWithGinHeaders(map[string]string{
|
||||||
|
"User-Agent": "client-ua",
|
||||||
|
}))
|
||||||
|
|
||||||
|
applyCodexHeaders(req, auth, "oauth-token", true, cfg)
|
||||||
|
|
||||||
|
if got := req.Header.Get("User-Agent"); got != "config-ua" {
|
||||||
|
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("x-codex-beta-features"); got != "" {
|
||||||
|
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error = %v", err)
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Metadata: map[string]any{"email": "user@example.com"},
|
||||||
|
}
|
||||||
|
req = req.WithContext(contextWithGinHeaders(map[string]string{
|
||||||
|
"Originator": "Codex Desktop",
|
||||||
|
"Version": "0.115.0-alpha.27",
|
||||||
|
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||||
|
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||||
|
}))
|
||||||
|
|
||||||
|
applyCodexHeaders(req, auth, "oauth-token", true, nil)
|
||||||
|
|
||||||
|
if got := req.Header.Get("Originator"); got != "Codex Desktop" {
|
||||||
|
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("Version"); got != "0.115.0-alpha.27" {
|
||||||
|
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
|
||||||
|
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||||
|
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexHeadersDoesNotInjectClientOnlyHeadersByDefault(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexHeaders(req, nil, "oauth-token", true, nil)
|
||||||
|
|
||||||
|
if got := req.Header.Get("Version"); got != "" {
|
||||||
|
t.Fatalf("Version = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != "" {
|
||||||
|
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("X-Client-Request-Id"); got != "" {
|
||||||
|
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextWithGinHeaders(headers map[string]string) context.Context {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
ginCtx.Request.Header = make(http.Header, len(headers))
|
||||||
|
for key, value := range headers {
|
||||||
|
ginCtx.Request.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
return context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dialer := newProxyAwareWebsocketDialer(
|
||||||
|
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||||
|
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if dialer.Proxy != nil {
|
||||||
|
t.Fatal("expected websocket proxy function to be nil for direct mode")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
1719
internal/runtime/executor/cursor_executor.go
Normal file
1719
internal/runtime/executor/cursor_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -224,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -401,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
if bytes.HasPrefix(line, dataTag) {
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||||
for i := range segments {
|
for i := range segments {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||||
for i := range segments {
|
for i := range segments {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
@@ -430,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
var param any
|
var param any
|
||||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||||
for i := range segments {
|
for i := range segments {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||||
}
|
}
|
||||||
|
|
||||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||||
for i := range segments {
|
for i := range segments {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||||
}
|
}
|
||||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||||
|
|
||||||
@@ -544,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
lastStatus = resp.StatusCode
|
lastStatus = resp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
@@ -811,18 +811,18 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
|||||||
|
|
||||||
if !hasInlineData {
|
if !hasInlineData {
|
||||||
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
||||||
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
|
||||||
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||||
newPartsJson := `[]`
|
newPartsJson := []byte(`[]`)
|
||||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
|
||||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
|
||||||
|
|
||||||
parts := contentArray[0].Get("parts").Array()
|
parts := contentArray[0].Get("parts").Array()
|
||||||
for j := 0; j < len(parts); j++ {
|
for j := 0; j < len(parts); j++ {
|
||||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
|
||||||
}
|
}
|
||||||
|
|
||||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson)
|
||||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,12 +321,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
@@ -415,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||||
@@ -527,18 +527,18 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
|||||||
|
|
||||||
if !hasInlineData {
|
if !hasInlineData {
|
||||||
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
||||||
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
|
||||||
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
||||||
newPartsJson := `[]`
|
newPartsJson := []byte(`[]`)
|
||||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
|
||||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
|
||||||
|
|
||||||
parts := contentArray[0].Get("parts").Array()
|
parts := contentArray[0].Get("parts").Array()
|
||||||
for j := 0; j < len(parts); j++ {
|
for j := 0; j < len(parts); j++ {
|
||||||
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
|
||||||
}
|
}
|
||||||
|
|
||||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", newPartsJson)
|
||||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -524,7 +524,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -636,12 +636,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
@@ -760,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||||
@@ -941,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||||
|
|||||||
@@ -221,13 +221,13 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
|
|
||||||
var param any
|
var param any
|
||||||
converted := ""
|
var converted []byte
|
||||||
if useResponses && from.String() == "claude" {
|
if useResponses && from.String() == "claude" {
|
||||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||||
} else {
|
} else {
|
||||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
}
|
}
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
resp = cliproxyexecutor.Response{Payload: converted}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -374,14 +374,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var chunks []string
|
var chunks [][]byte
|
||||||
if useResponses && from.String() == "claude" {
|
if useResponses && from.String() == "claude" {
|
||||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||||
} else {
|
} else {
|
||||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||||
}
|
}
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,9 +577,33 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
||||||
|
if info := registry.GetGlobalRegistry().GetModelInfo(baseModel, githubCopilotAuthType); info != nil {
|
||||||
|
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
|
||||||
|
}
|
||||||
|
if info := lookupGitHubCopilotStaticModelInfo(baseModel); info != nil {
|
||||||
|
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
|
||||||
|
}
|
||||||
return strings.Contains(baseModel, "codex")
|
return strings.Contains(baseModel, "codex")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
|
||||||
|
for _, info := range registry.GetStaticModelDefinitionsByChannel(githubCopilotAuthType) {
|
||||||
|
if info != nil && strings.EqualFold(info.ID, model) {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsEndpoint(endpoints []string, endpoint string) bool {
|
||||||
|
for _, item := range endpoints {
|
||||||
|
if item == endpoint {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// flattenAssistantContent converts assistant message content from array format
|
// flattenAssistantContent converts assistant message content from array format
|
||||||
// to a joined string. GitHub Copilot requires assistant content as a string;
|
// to a joined string. GitHub Copilot requires assistant content as a string;
|
||||||
// sending it as an array causes Claude models to re-answer all previous prompts.
|
// sending it as an array causes Claude models to re-answer all previous prompts.
|
||||||
@@ -653,6 +677,7 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||||
|
body = stripGitHubCopilotResponsesUnsupportedFields(body)
|
||||||
input := gjson.GetBytes(body, "input")
|
input := gjson.GetBytes(body, "input")
|
||||||
if input.Exists() {
|
if input.Exists() {
|
||||||
// If input is already a string or array, keep it as-is.
|
// If input is already a string or array, keep it as-is.
|
||||||
@@ -825,6 +850,12 @@ func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
|
||||||
|
// GitHub Copilot /responses rejects service_tier, so always remove it.
|
||||||
|
body, _ = sjson.DeleteBytes(body, "service_tier")
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||||
tools := gjson.GetBytes(body, "tools")
|
tools := gjson.GetBytes(body, "tools")
|
||||||
if tools.Exists() {
|
if tools.Exists() {
|
||||||
@@ -970,7 +1001,7 @@ type githubCopilotResponsesStreamState struct {
|
|||||||
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
||||||
}
|
}
|
||||||
|
|
||||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) []byte {
|
||||||
root := gjson.ParseBytes(data)
|
root := gjson.ParseBytes(data)
|
||||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||||
@@ -1060,10 +1091,10 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
|||||||
} else {
|
} else {
|
||||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||||
}
|
}
|
||||||
return out
|
return []byte(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
|
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) [][]byte {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &githubCopilotResponsesStreamState{
|
*param = &githubCopilotResponsesStreamState{
|
||||||
TextBlockIndex: -1,
|
TextBlockIndex: -1,
|
||||||
@@ -1085,7 +1116,10 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
}
|
}
|
||||||
|
|
||||||
event := gjson.GetBytes(payload, "type").String()
|
event := gjson.GetBytes(payload, "type").String()
|
||||||
results := make([]string, 0, 4)
|
results := make([][]byte, 0, 4)
|
||||||
|
appendResult := func(chunk string) {
|
||||||
|
results = append(results, []byte(chunk))
|
||||||
|
}
|
||||||
ensureMessageStart := func() {
|
ensureMessageStart := func() {
|
||||||
if state.MessageStarted {
|
if state.MessageStarted {
|
||||||
return
|
return
|
||||||
@@ -1093,7 +1127,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
||||||
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
||||||
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
||||||
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
|
appendResult("event: message_start\ndata: " + messageStart + "\n\n")
|
||||||
state.MessageStarted = true
|
state.MessageStarted = true
|
||||||
}
|
}
|
||||||
startTextBlockIfNeeded := func() {
|
startTextBlockIfNeeded := func() {
|
||||||
@@ -1106,7 +1140,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
}
|
}
|
||||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
||||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
|
||||||
state.TextBlockStarted = true
|
state.TextBlockStarted = true
|
||||||
}
|
}
|
||||||
stopTextBlockIfNeeded := func() {
|
stopTextBlockIfNeeded := func() {
|
||||||
@@ -1115,7 +1149,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
}
|
}
|
||||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
||||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
|
||||||
state.TextBlockStarted = false
|
state.TextBlockStarted = false
|
||||||
state.TextBlockIndex = -1
|
state.TextBlockIndex = -1
|
||||||
}
|
}
|
||||||
@@ -1145,7 +1179,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||||
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
||||||
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
||||||
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
|
appendResult("event: content_block_delta\ndata: " + contentDelta + "\n\n")
|
||||||
}
|
}
|
||||||
case "response.reasoning_summary_part.added":
|
case "response.reasoning_summary_part.added":
|
||||||
ensureMessageStart()
|
ensureMessageStart()
|
||||||
@@ -1154,7 +1188,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
state.NextContentIndex++
|
state.NextContentIndex++
|
||||||
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||||
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
|
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
|
||||||
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
|
appendResult("event: content_block_start\ndata: " + thinkingStart + "\n\n")
|
||||||
case "response.reasoning_summary_text.delta":
|
case "response.reasoning_summary_text.delta":
|
||||||
if state.ReasoningActive {
|
if state.ReasoningActive {
|
||||||
delta := gjson.GetBytes(payload, "delta").String()
|
delta := gjson.GetBytes(payload, "delta").String()
|
||||||
@@ -1162,14 +1196,14 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
|
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
|
||||||
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
|
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
|
||||||
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
|
appendResult("event: content_block_delta\ndata: " + thinkingDelta + "\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "response.reasoning_summary_part.done":
|
case "response.reasoning_summary_part.done":
|
||||||
if state.ReasoningActive {
|
if state.ReasoningActive {
|
||||||
thinkingStop := `{"type":"content_block_stop","index":0}`
|
thinkingStop := `{"type":"content_block_stop","index":0}`
|
||||||
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
|
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
|
||||||
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
|
appendResult("event: content_block_stop\ndata: " + thinkingStop + "\n\n")
|
||||||
state.ReasoningActive = false
|
state.ReasoningActive = false
|
||||||
}
|
}
|
||||||
case "response.output_item.added":
|
case "response.output_item.added":
|
||||||
@@ -1197,7 +1231,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
||||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
||||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
||||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
|
||||||
case "response.output_item.delta":
|
case "response.output_item.delta":
|
||||||
item := gjson.GetBytes(payload, "item")
|
item := gjson.GetBytes(payload, "item")
|
||||||
if item.Get("type").String() != "function_call" {
|
if item.Get("type").String() != "function_call" {
|
||||||
@@ -1217,7 +1251,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
|
||||||
case "response.function_call_arguments.delta":
|
case "response.function_call_arguments.delta":
|
||||||
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
|
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
|
||||||
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
|
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
|
||||||
@@ -1234,7 +1268,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
|
||||||
case "response.output_item.done":
|
case "response.output_item.done":
|
||||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||||
break
|
break
|
||||||
@@ -1245,7 +1279,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
}
|
}
|
||||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
||||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
|
||||||
case "response.completed":
|
case "response.completed":
|
||||||
ensureMessageStart()
|
ensureMessageStart()
|
||||||
stopTextBlockIfNeeded()
|
stopTextBlockIfNeeded()
|
||||||
@@ -1269,8 +1303,8 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
|
|||||||
if cachedTokens > 0 {
|
if cachedTokens > 0 {
|
||||||
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
|
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
}
|
}
|
||||||
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
|
appendResult("event: message_delta\ndata: " + messageDelta + "\n\n")
|
||||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
appendResult("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||||
state.MessageStopSent = true
|
state.MessageStopSent = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
@@ -70,6 +71,29 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||||
|
t.Fatal("expected responses-only registry model to use /responses")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "github-copilot-test-client"
|
||||||
|
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
|
||||||
|
ID: "gpt-5.4",
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||||
|
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||||
@@ -132,6 +156,19 @@ func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"input":"user text","service_tier":"default"}`)
|
||||||
|
got := normalizeGitHubCopilotResponsesInput(body)
|
||||||
|
|
||||||
|
if gjson.GetBytes(got, "service_tier").Exists() {
|
||||||
|
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(got, "input").String() != "user text" {
|
||||||
|
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||||
|
|||||||
1374
internal/runtime/executor/gitlab_executor.go
Normal file
1374
internal/runtime/executor/gitlab_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
539
internal/runtime/executor/gitlab_executor_test.go
Normal file
539
internal/runtime/executor/gitlab_executor_test.go
Normal file
@@ -0,0 +1,539 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteUsesChatEndpoint(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != gitLabChatEndpoint {
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(`"chat response"`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"base_url": srv.URL,
|
||||||
|
"access_token": "oauth-access",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "chat response" {
|
||||||
|
t.Fatalf("expected chat response, got %q", got)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("expected resolved model, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteFallsBackToCodeSuggestions(t *testing.T) {
|
||||||
|
chatCalls := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case gitLabChatEndpoint:
|
||||||
|
chatCalls++
|
||||||
|
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||||
|
case gitLabCodeSuggestionsEndpoint:
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{{
|
||||||
|
"text": "fallback response",
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"base_url": srv.URL,
|
||||||
|
"personal_access_token": "glpat-token",
|
||||||
|
"auth_method": "pat",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"write code"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if chatCalls != 1 {
|
||||||
|
t.Fatalf("expected chat endpoint to be tried once, got %d", chatCalls)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "fallback response" {
|
||||||
|
t.Fatalf("expected fallback response, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteUsesAnthropicGateway(t *testing.T) {
|
||||||
|
var gotAuthHeader, gotRealmHeader string
|
||||||
|
var gotPath string
|
||||||
|
var gotModel string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
gotAuthHeader = r.Header.Get("Authorization")
|
||||||
|
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||||
|
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"duo_gateway_base_url": srv.URL,
|
||||||
|
"duo_gateway_token": "gateway-token",
|
||||||
|
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{
|
||||||
|
"model":"gitlab-duo",
|
||||||
|
"messages":[{"role":"user","content":[{"type":"text","text":"list files"}]}],
|
||||||
|
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}],
|
||||||
|
"max_tokens":128
|
||||||
|
}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||||
|
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||||
|
}
|
||||||
|
if gotAuthHeader != "Bearer gateway-token" {
|
||||||
|
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||||
|
}
|
||||||
|
if gotRealmHeader != "saas" {
|
||||||
|
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||||
|
}
|
||||||
|
if gotModel != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("model = %q, want claude-sonnet-4-5", gotModel)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "content.0.type").String(); got != "tool_use" {
|
||||||
|
t.Fatalf("expected tool_use response, got %q", got)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "content.0.name").String(); got != "Bash" {
|
||||||
|
t.Fatalf("expected tool name Bash, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||||
|
var gotAuthHeader, gotRealmHeader string
|
||||||
|
var gotPath string
|
||||||
|
var gotModel string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
gotAuthHeader = r.Header.Get("Authorization")
|
||||||
|
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||||
|
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from openai gateway\"}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from openai gateway\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"duo_gateway_base_url": srv.URL,
|
||||||
|
"duo_gateway_token": "gateway-token",
|
||||||
|
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||||
|
"model_provider": "openai",
|
||||||
|
"model_name": "gpt-5-codex",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||||
|
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||||
|
}
|
||||||
|
if gotAuthHeader != "Bearer gateway-token" {
|
||||||
|
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||||
|
}
|
||||||
|
if gotRealmHeader != "saas" {
|
||||||
|
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||||
|
}
|
||||||
|
if gotModel != "gpt-5-codex" {
|
||||||
|
t.Fatalf("model = %q, want gpt-5-codex", gotModel)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from openai gateway" {
|
||||||
|
t.Fatalf("expected openai gateway response, got %q payload=%s", got, string(resp.Payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
|
||||||
|
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
|
||||||
|
var gotPath string
|
||||||
|
var gotModel string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
gotAuthHeader = r.Header.Get("Authorization")
|
||||||
|
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||||
|
gotBetaHeader = r.Header.Get("anthropic-beta")
|
||||||
|
gotUserAgent = r.Header.Get("User-Agent")
|
||||||
|
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"duo_gateway_base_url": srv.URL,
|
||||||
|
"duo_gateway_token": "gateway-token",
|
||||||
|
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "duo-chat-gpt-5-codex",
|
||||||
|
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||||
|
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||||
|
}
|
||||||
|
if gotAuthHeader != "Bearer gateway-token" {
|
||||||
|
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||||
|
}
|
||||||
|
if gotRealmHeader != "saas" {
|
||||||
|
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||||
|
}
|
||||||
|
if gotBetaHeader != gitLabContext1MBeta {
|
||||||
|
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
|
||||||
|
}
|
||||||
|
if gotUserAgent != gitLabNativeUserAgent {
|
||||||
|
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||||
|
}
|
||||||
|
if gotModel != "duo-chat-gpt-5-codex" {
|
||||||
|
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
|
||||||
|
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/oauth/token":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"access_token": "oauth-refreshed",
|
||||||
|
"refresh_token": "oauth-refresh",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "api read_user",
|
||||||
|
"created_at": 1710000000,
|
||||||
|
"expires_in": 3600,
|
||||||
|
})
|
||||||
|
case "/api/v4/code_suggestions/direct_access":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"base_url": "https://cloud.gitlab.example.com",
|
||||||
|
"token": "gateway-token",
|
||||||
|
"expires_at": 1710003600,
|
||||||
|
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "gitlab-auth.json",
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"base_url": srv.URL,
|
||||||
|
"access_token": "oauth-access",
|
||||||
|
"refresh_token": "oauth-refresh",
|
||||||
|
"oauth_client_id": "client-id",
|
||||||
|
"auth_method": "oauth",
|
||||||
|
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := exec.Refresh(context.Background(), auth)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Refresh() error = %v", err)
|
||||||
|
}
|
||||||
|
if got := updated.Metadata["access_token"]; got != "oauth-refreshed" {
|
||||||
|
t.Fatalf("expected refreshed access token, got %#v", got)
|
||||||
|
}
|
||||||
|
if got := updated.Metadata["model_name"]; got != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("expected refreshed model metadata, got %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteStreamUsesCodeSuggestionsSSE(t *testing.T) {
|
||||||
|
var gotAccept, gotStreamingHeader, gotEncoding string
|
||||||
|
var gotStreamFlag bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != gitLabCodeSuggestionsEndpoint {
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
gotAccept = r.Header.Get("Accept")
|
||||||
|
gotStreamingHeader = r.Header.Get(gitLabSSEStreamingHeader)
|
||||||
|
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||||
|
gotStreamFlag = gjson.GetBytes(readBody(t, r), "stream").Bool()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("event: stream_start\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"model\":{\"name\":\"claude-sonnet-4-5\"}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"content\":\"hello\"}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"content\":\" world\"}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: stream_end\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {}\n\n"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"base_url": srv.URL,
|
||||||
|
"access_token": "oauth-access",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := collectStreamLines(t, result)
|
||||||
|
if gotAccept != "text/event-stream" {
|
||||||
|
t.Fatalf("Accept = %q, want text/event-stream", gotAccept)
|
||||||
|
}
|
||||||
|
if gotStreamingHeader != "true" {
|
||||||
|
t.Fatalf("%s = %q, want true", gitLabSSEStreamingHeader, gotStreamingHeader)
|
||||||
|
}
|
||||||
|
if gotEncoding != "identity" {
|
||||||
|
t.Fatalf("Accept-Encoding = %q, want identity", gotEncoding)
|
||||||
|
}
|
||||||
|
if !gotStreamFlag {
|
||||||
|
t.Fatalf("expected upstream request to set stream=true")
|
||||||
|
}
|
||||||
|
if len(lines) < 4 {
|
||||||
|
t.Fatalf("expected translated stream chunks, got %d", len(lines))
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.Join(lines, "\n"), `"content":"hello"`) {
|
||||||
|
t.Fatalf("expected hello delta in stream, got %q", strings.Join(lines, "\n"))
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.Join(lines, "\n"), `"content":" world"`) {
|
||||||
|
t.Fatalf("expected world delta in stream, got %q", strings.Join(lines, "\n"))
|
||||||
|
}
|
||||||
|
last := lines[len(lines)-1]
|
||||||
|
if last != "data: [DONE]" && !strings.Contains(last, `"finish_reason":"stop"`) {
|
||||||
|
t.Fatalf("expected stream terminator, got %q", last)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||||
|
chatCalls := 0
|
||||||
|
streamCalls := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case gitLabCodeSuggestionsEndpoint:
|
||||||
|
streamCalls++
|
||||||
|
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||||
|
case gitLabChatEndpoint:
|
||||||
|
chatCalls++
|
||||||
|
_, _ = w.Write([]byte(`"chat fallback response"`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"base_url": srv.URL,
|
||||||
|
"access_token": "oauth-access",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := collectStreamLines(t, result)
|
||||||
|
if streamCalls != 1 {
|
||||||
|
t.Fatalf("expected streaming endpoint once, got %d", streamCalls)
|
||||||
|
}
|
||||||
|
if chatCalls != 1 {
|
||||||
|
t.Fatalf("expected chat fallback once, got %d", chatCalls)
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.Join(lines, "\n"), `"content":"chat fallback response"`) {
|
||||||
|
t.Fatalf("expected fallback content in stream, got %q", strings.Join(lines, "\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||||
|
var gotPath, gotBetaHeader, gotUserAgent string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
gotBetaHeader = r.Header.Get("Anthropic-Beta")
|
||||||
|
gotUserAgent = r.Header.Get("User-Agent")
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("event: message_start\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: content_block_start\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: content_block_delta\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from gateway\"}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: message_delta\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: message_stop\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"duo_gateway_base_url": srv.URL,
|
||||||
|
"duo_gateway_token": "gateway-token",
|
||||||
|
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "gitlab-duo",
|
||||||
|
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"max_tokens":64}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := collectStreamLines(t, result)
|
||||||
|
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||||
|
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||||
|
}
|
||||||
|
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
|
||||||
|
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
|
||||||
|
}
|
||||||
|
if gotUserAgent != gitLabNativeUserAgent {
|
||||||
|
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||||
|
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectStreamLines(t *testing.T, result *cliproxyexecutor.StreamResult) []string {
|
||||||
|
t.Helper()
|
||||||
|
lines := make([]string, 0, 8)
|
||||||
|
for chunk := range result.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
lines = append(lines, string(chunk.Payload))
|
||||||
|
}
|
||||||
|
return lines
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBody(t *testing.T, r *http.Request) []byte {
|
||||||
|
t.Helper()
|
||||||
|
defer func() { _ = r.Body.Close() }()
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error = %v", err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
@@ -169,7 +169,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,7 +281,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
@@ -315,7 +315,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
usageJSON := buildOpenAIUsageJSON(count)
|
usageJSON := buildOpenAIUsageJSON(count)
|
||||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.
|
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,12 +271,12 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||||
for i := range doneChunks {
|
for i := range doneChunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
|||||||
@@ -89,6 +89,13 @@ var endpointAliases = map[string]string{
|
|||||||
"cli": "amazonq",
|
"cli": "amazonq",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func enqueueTranslatedSSE(out chan<- cliproxyexecutor.StreamChunk, chunk []byte) {
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: append(bytes.Clone(chunk), '\n', '\n')}
|
||||||
|
}
|
||||||
|
|
||||||
// retryConfig holds configuration for socket retry logic.
|
// retryConfig holds configuration for socket retry logic.
|
||||||
// Based on kiro2Api Python implementation patterns.
|
// Based on kiro2Api Python implementation patterns.
|
||||||
type retryConfig struct {
|
type retryConfig struct {
|
||||||
@@ -2573,9 +2580,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send tool input as delta
|
// Send tool input as delta
|
||||||
@@ -2583,18 +2588,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
|
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close block
|
// Close block
|
||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hasToolUses = true
|
hasToolUses = true
|
||||||
@@ -2664,9 +2665,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
|
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
messageStartSent = true
|
messageStartSent = true
|
||||||
}
|
}
|
||||||
@@ -2916,9 +2915,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
|
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lastReportedOutputTokens = currentOutputTokens
|
lastReportedOutputTokens = currentOutputTokens
|
||||||
@@ -2939,17 +2936,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -2978,18 +2971,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Send thinking delta
|
// Send thinking delta
|
||||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
accumulatedThinkingContent.WriteString(thinkingText)
|
accumulatedThinkingContent.WriteString(thinkingText)
|
||||||
}
|
}
|
||||||
@@ -2998,9 +2987,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
isThinkingBlockOpen = false
|
isThinkingBlockOpen = false
|
||||||
}
|
}
|
||||||
@@ -3029,17 +3016,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
|
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
accumulatedThinkingContent.WriteString(processContent)
|
accumulatedThinkingContent.WriteString(processContent)
|
||||||
}
|
}
|
||||||
@@ -3058,9 +3041,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
isThinkingBlockOpen = false
|
isThinkingBlockOpen = false
|
||||||
}
|
}
|
||||||
@@ -3071,18 +3052,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Send text delta
|
// Send text delta
|
||||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Close text block before entering thinking
|
// Close text block before entering thinking
|
||||||
@@ -3090,9 +3067,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
isTextBlockOpen = false
|
isTextBlockOpen = false
|
||||||
}
|
}
|
||||||
@@ -3120,17 +3095,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3158,9 +3129,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
isTextBlockOpen = false
|
isTextBlockOpen = false
|
||||||
}
|
}
|
||||||
@@ -3171,9 +3140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send input_json_delta with the tool input
|
// Send input_json_delta with the tool input
|
||||||
@@ -3186,9 +3153,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3197,9 +3162,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3239,9 +3202,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
isTextBlockOpen = false
|
isTextBlockOpen = false
|
||||||
}
|
}
|
||||||
@@ -3254,9 +3215,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3264,9 +3223,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate for token counting
|
// Accumulate for token counting
|
||||||
@@ -3298,9 +3255,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
isTextBlockOpen = false
|
isTextBlockOpen = false
|
||||||
}
|
}
|
||||||
@@ -3310,9 +3265,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tu.Input != nil {
|
if tu.Input != nil {
|
||||||
@@ -3323,9 +3276,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3333,9 +3284,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3522,9 +3471,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3609,18 +3556,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
|
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
|
||||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send message_stop event separately
|
// Send message_stop event separately
|
||||||
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
|
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
|
||||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
|
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
|
||||||
for _, chunk := range sseData {
|
for _, chunk := range sseData {
|
||||||
if chunk != "" {
|
enqueueTranslatedSSE(out, chunk)
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// reporter.publish is called via defer
|
// reporter.publish is called via defer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
// Translate response back to source format when needed
|
// Translate response back to source format when needed
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,6 +205,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Request usage data in the final streaming chunk so that token statistics
|
||||||
|
// are captured even when the upstream is an OpenAI-compatible provider.
|
||||||
|
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -286,7 +290,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
// Pass through translator; it yields one or more chunks for the target schema.
|
// Pass through translator; it yields one or more chunks for the target schema.
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
@@ -326,7 +330,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
|
|
||||||
usageJSON := buildOpenAIUsageJSON(count)
|
usageJSON := buildOpenAIUsageJSON(count)
|
||||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
|
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh is a no-op for API-key based compatibility providers.
|
// Refresh is a no-op for API-key based compatibility providers.
|
||||||
|
|||||||
@@ -2,17 +2,15 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||||
@@ -111,45 +109,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
||||||
func buildProxyTransport(proxyURL string) *http.Transport {
|
func buildProxyTransport(proxyURL string) *http.Transport {
|
||||||
if proxyURL == "" {
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
|
||||||
|
if errBuild != nil {
|
||||||
|
log.Errorf("%v", errBuild)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, errParse := url.Parse(proxyURL)
|
|
||||||
if errParse != nil {
|
|
||||||
log.Errorf("parse proxy URL failed: %v", errParse)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var transport *http.Transport
|
|
||||||
|
|
||||||
// Handle different proxy schemes
|
|
||||||
if parsedURL.Scheme == "socks5" {
|
|
||||||
// Configure SOCKS5 proxy with optional authentication
|
|
||||||
var proxyAuth *proxy.Auth
|
|
||||||
if parsedURL.User != nil {
|
|
||||||
username := parsedURL.User.Username()
|
|
||||||
password, _ := parsedURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Set up a custom transport using the SOCKS5 dialer
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
|
|
||||||
// Configure HTTP or HTTPS proxy
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
|
|
||||||
} else {
|
|
||||||
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return transport
|
return transport
|
||||||
}
|
}
|
||||||
|
|||||||
30
internal/runtime/executor/proxy_helpers_test.go
Normal file
30
internal/runtime/executor/proxy_helpers_test.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := newProxyAwareHTTPClient(
|
||||||
|
context.Background(),
|
||||||
|
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||||
|
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport, ok := client.Transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -305,7 +305,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -421,12 +421,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||||
for i := range doneChunks {
|
for i := range doneChunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
@@ -461,7 +461,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
|
|
||||||
usageJSON := buildOpenAIUsageJSON(count)
|
usageJSON := buildOpenAIUsageJSON(count)
|
||||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
|||||||
@@ -73,17 +73,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.once.Do(func() {
|
r.once.Do(func() {
|
||||||
usage.PublishRecord(ctx, usage.Record{
|
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
|
||||||
Provider: r.provider,
|
|
||||||
Model: r.model,
|
|
||||||
Source: r.source,
|
|
||||||
APIKey: r.apiKey,
|
|
||||||
AuthID: r.authID,
|
|
||||||
AuthIndex: r.authIndex,
|
|
||||||
RequestedAt: r.requestedAt,
|
|
||||||
Failed: failed,
|
|
||||||
Detail: detail,
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,20 +86,39 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.once.Do(func() {
|
r.once.Do(func() {
|
||||||
usage.PublishRecord(ctx, usage.Record{
|
usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
|
||||||
Provider: r.provider,
|
|
||||||
Model: r.model,
|
|
||||||
Source: r.source,
|
|
||||||
APIKey: r.apiKey,
|
|
||||||
AuthID: r.authID,
|
|
||||||
AuthIndex: r.authIndex,
|
|
||||||
RequestedAt: r.requestedAt,
|
|
||||||
Failed: false,
|
|
||||||
Detail: usage.Detail{},
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||||
|
if r == nil {
|
||||||
|
return usage.Record{Detail: detail, Failed: failed}
|
||||||
|
}
|
||||||
|
return usage.Record{
|
||||||
|
Provider: r.provider,
|
||||||
|
Model: r.model,
|
||||||
|
Source: r.source,
|
||||||
|
APIKey: r.apiKey,
|
||||||
|
AuthID: r.authID,
|
||||||
|
AuthIndex: r.authIndex,
|
||||||
|
RequestedAt: r.requestedAt,
|
||||||
|
Latency: r.latency(),
|
||||||
|
Failed: failed,
|
||||||
|
Detail: detail,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageReporter) latency() time.Duration {
|
||||||
|
if r == nil || r.requestedAt.IsZero() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
latency := time.Since(r.requestedAt)
|
||||||
|
if latency < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return latency
|
||||||
|
}
|
||||||
|
|
||||||
func apiKeyFromContext(ctx context.Context) string {
|
func apiKeyFromContext(ctx context.Context) string {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
)
|
||||||
|
|
||||||
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||||
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||||
@@ -41,3 +46,19 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
|
|||||||
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
|
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||||
|
reporter := &usageReporter{
|
||||||
|
provider: "openai",
|
||||||
|
model: "gpt-5.4",
|
||||||
|
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
||||||
|
}
|
||||||
|
|
||||||
|
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
|
||||||
|
if record.Latency < time.Second {
|
||||||
|
t.Fatalf("latency = %v, want >= 1s", record.Latency)
|
||||||
|
}
|
||||||
|
if record.Latency > 3*time.Second {
|
||||||
|
t.Fatalf("latency = %v, want <= 3s", record.Latency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -39,35 +40,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
rawJSON := inputRawJSON
|
rawJSON := inputRawJSON
|
||||||
|
|
||||||
// system instruction
|
// system instruction
|
||||||
systemInstructionJSON := ""
|
var systemInstructionJSON []byte
|
||||||
hasSystemInstruction := false
|
hasSystemInstruction := false
|
||||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||||
if systemResult.IsArray() {
|
if systemResult.IsArray() {
|
||||||
systemResults := systemResult.Array()
|
systemResults := systemResult.Array()
|
||||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
|
||||||
for i := 0; i < len(systemResults); i++ {
|
for i := 0; i < len(systemResults); i++ {
|
||||||
systemPromptResult := systemResults[i]
|
systemPromptResult := systemResults[i]
|
||||||
systemTypePromptResult := systemPromptResult.Get("type")
|
systemTypePromptResult := systemPromptResult.Get("type")
|
||||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||||
systemPrompt := systemPromptResult.Get("text").String()
|
systemPrompt := systemPromptResult.Get("text").String()
|
||||||
partJSON := `{}`
|
partJSON := []byte(`{}`)
|
||||||
if systemPrompt != "" {
|
if systemPrompt != "" {
|
||||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
|
||||||
}
|
}
|
||||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", partJSON)
|
||||||
hasSystemInstruction = true
|
hasSystemInstruction = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if systemResult.Type == gjson.String {
|
} else if systemResult.Type == gjson.String {
|
||||||
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
|
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
|
||||||
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
|
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||||
hasSystemInstruction = true
|
hasSystemInstruction = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// contents
|
// contents
|
||||||
contentsJSON := "[]"
|
contentsJSON := []byte(`[]`)
|
||||||
hasContents := false
|
hasContents := false
|
||||||
|
|
||||||
|
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
|
||||||
|
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
|
||||||
|
toolNameByID := make(map[string]string)
|
||||||
|
|
||||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||||
if messagesResult.IsArray() {
|
if messagesResult.IsArray() {
|
||||||
messageResults := messagesResult.Array()
|
messageResults := messagesResult.Array()
|
||||||
@@ -83,8 +88,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if role == "assistant" {
|
if role == "assistant" {
|
||||||
role = "model"
|
role = "model"
|
||||||
}
|
}
|
||||||
clientContentJSON := `{"role":"","parts":[]}`
|
clientContentJSON := []byte(`{"role":"","parts":[]}`)
|
||||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "role", role)
|
||||||
contentsResult := messageResult.Get("content")
|
contentsResult := messageResult.Get("content")
|
||||||
if contentsResult.IsArray() {
|
if contentsResult.IsArray() {
|
||||||
contentResults := contentsResult.Array()
|
contentResults := contentsResult.Array()
|
||||||
@@ -143,15 +148,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Valid signature, send as thought block
|
// Valid signature, send as thought block
|
||||||
partJSON := `{}`
|
// Always include "text" field — Google Antigravity API requires it
|
||||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
// even for redacted thinking where the text is empty.
|
||||||
if thinkingText != "" {
|
partJSON := []byte(`{}`)
|
||||||
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
|
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
|
||||||
}
|
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
|
||||||
if signature != "" {
|
if signature != "" {
|
||||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
|
||||||
}
|
}
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||||
prompt := contentResult.Get("text").String()
|
prompt := contentResult.Get("text").String()
|
||||||
// Skip empty text parts to avoid Gemini API error:
|
// Skip empty text parts to avoid Gemini API error:
|
||||||
@@ -159,17 +164,21 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if prompt == "" {
|
if prompt == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
partJSON := `{}`
|
partJSON := []byte(`{}`)
|
||||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||||
// Antigravity API validates signatures, so dummy values are rejected.
|
// Antigravity API validates signatures, so dummy values are rejected.
|
||||||
|
|
||||||
functionName := contentResult.Get("name").String()
|
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
|
||||||
argsResult := contentResult.Get("input")
|
argsResult := contentResult.Get("input")
|
||||||
functionID := contentResult.Get("id").String()
|
functionID := contentResult.Get("id").String()
|
||||||
|
|
||||||
|
if functionID != "" && functionName != "" {
|
||||||
|
toolNameByID[functionID] = functionName
|
||||||
|
}
|
||||||
|
|
||||||
// Handle both object and string input formats
|
// Handle both object and string input formats
|
||||||
var argsRaw string
|
var argsRaw string
|
||||||
if argsResult.IsObject() {
|
if argsResult.IsObject() {
|
||||||
@@ -183,138 +192,147 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
|
|
||||||
if argsRaw != "" {
|
if argsRaw != "" {
|
||||||
partJSON := `{}`
|
partJSON := []byte(`{}`)
|
||||||
|
|
||||||
// Use skip_thought_signature_validator for tool calls without valid thinking signature
|
// Use skip_thought_signature_validator for tool calls without valid thinking signature
|
||||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||||
// and also works for Claude through Antigravity API
|
// and also works for Claude through Antigravity API
|
||||||
const skipSentinel = "skip_thought_signature_validator"
|
const skipSentinel = "skip_thought_signature_validator"
|
||||||
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||||
} else {
|
} else {
|
||||||
// No valid signature - use skip sentinel to bypass validation
|
// No valid signature - use skip sentinel to bypass validation
|
||||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
|
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", skipSentinel)
|
||||||
}
|
}
|
||||||
|
|
||||||
if functionID != "" {
|
if functionID != "" {
|
||||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.id", functionID)
|
||||||
}
|
}
|
||||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.name", functionName)
|
||||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
|
partJSON, _ = sjson.SetRawBytes(partJSON, "functionCall.args", []byte(argsRaw))
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||||
}
|
}
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||||
toolCallID := contentResult.Get("tool_use_id").String()
|
toolCallID := contentResult.Get("tool_use_id").String()
|
||||||
if toolCallID != "" {
|
if toolCallID != "" {
|
||||||
funcName := toolCallID
|
funcName, ok := toolNameByID[toolCallID]
|
||||||
toolCallIDs := strings.Split(toolCallID, "-")
|
if !ok {
|
||||||
if len(toolCallIDs) > 1 {
|
// Fallback: derive a semantic name from the ID by stripping
|
||||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
|
||||||
|
// Only use the raw ID as a last resort when the heuristic produces an empty string.
|
||||||
|
parts := strings.Split(toolCallID, "-")
|
||||||
|
if len(parts) > 2 {
|
||||||
|
funcName = strings.Join(parts[:len(parts)-2], "-")
|
||||||
|
}
|
||||||
|
if funcName == "" {
|
||||||
|
funcName = toolCallID
|
||||||
|
}
|
||||||
|
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
|
||||||
}
|
}
|
||||||
functionResponseResult := contentResult.Get("content")
|
functionResponseResult := contentResult.Get("content")
|
||||||
|
|
||||||
functionResponseJSON := `{}`
|
functionResponseJSON := []byte(`{}`)
|
||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
|
||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
|
||||||
|
|
||||||
responseData := ""
|
responseData := ""
|
||||||
if functionResponseResult.Type == gjson.String {
|
if functionResponseResult.Type == gjson.String {
|
||||||
responseData = functionResponseResult.String()
|
responseData = functionResponseResult.String()
|
||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", responseData)
|
||||||
} else if functionResponseResult.IsArray() {
|
} else if functionResponseResult.IsArray() {
|
||||||
frResults := functionResponseResult.Array()
|
frResults := functionResponseResult.Array()
|
||||||
nonImageCount := 0
|
nonImageCount := 0
|
||||||
lastNonImageRaw := ""
|
lastNonImageRaw := ""
|
||||||
filteredJSON := "[]"
|
filteredJSON := []byte(`[]`)
|
||||||
imagePartsJSON := "[]"
|
imagePartsJSON := []byte(`[]`)
|
||||||
for _, fr := range frResults {
|
for _, fr := range frResults {
|
||||||
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
|
||||||
inlineDataJSON := `{}`
|
inlineDataJSON := []byte(`{}`)
|
||||||
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||||
}
|
}
|
||||||
if data := fr.Get("source.data").String(); data != "" {
|
if data := fr.Get("source.data").String(); data != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
imagePartJSON := `{}`
|
imagePartJSON := []byte(`{}`)
|
||||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
nonImageCount++
|
nonImageCount++
|
||||||
lastNonImageRaw = fr.Raw
|
lastNonImageRaw = fr.Raw
|
||||||
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
|
filteredJSON, _ = sjson.SetRawBytes(filteredJSON, "-1", []byte(fr.Raw))
|
||||||
}
|
}
|
||||||
|
|
||||||
if nonImageCount == 1 {
|
if nonImageCount == 1 {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
|
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(lastNonImageRaw))
|
||||||
} else if nonImageCount > 1 {
|
} else if nonImageCount > 1 {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
|
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", filteredJSON)
|
||||||
} else {
|
} else {
|
||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Place image data inside functionResponse.parts as inlineData
|
// Place image data inside functionResponse.parts as inlineData
|
||||||
// instead of as sibling parts in the outer content, to avoid
|
// instead of as sibling parts in the outer content, to avoid
|
||||||
// base64 data bloating the text context.
|
// base64 data bloating the text context.
|
||||||
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
|
if gjson.GetBytes(imagePartsJSON, "#").Int() > 0 {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if functionResponseResult.IsObject() {
|
} else if functionResponseResult.IsObject() {
|
||||||
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
|
||||||
inlineDataJSON := `{}`
|
inlineDataJSON := []byte(`{}`)
|
||||||
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||||
}
|
}
|
||||||
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
if data := functionResponseResult.Get("source.data").String(); data != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
imagePartJSON := `{}`
|
imagePartJSON := []byte(`{}`)
|
||||||
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
|
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
|
||||||
imagePartsJSON := "[]"
|
imagePartsJSON := []byte(`[]`)
|
||||||
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
|
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
|
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
|
||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||||
} else {
|
} else {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
|
||||||
}
|
}
|
||||||
} else if functionResponseResult.Raw != "" {
|
} else if functionResponseResult.Raw != "" {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
|
||||||
} else {
|
} else {
|
||||||
// Content field is missing entirely — .Raw is empty which
|
// Content field is missing entirely — .Raw is empty which
|
||||||
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
partJSON := `{}`
|
partJSON := []byte(`{}`)
|
||||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
partJSON, _ = sjson.SetRawBytes(partJSON, "functionResponse", functionResponseJSON)
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||||
}
|
}
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||||
sourceResult := contentResult.Get("source")
|
sourceResult := contentResult.Get("source")
|
||||||
if sourceResult.Get("type").String() == "base64" {
|
if sourceResult.Get("type").String() == "base64" {
|
||||||
inlineDataJSON := `{}`
|
inlineDataJSON := []byte(`{}`)
|
||||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
|
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
|
||||||
}
|
}
|
||||||
if data := sourceResult.Get("data").String(); data != "" {
|
if data := sourceResult.Get("data").String(); data != "" {
|
||||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
partJSON := `{}`
|
partJSON := []byte(`{}`)
|
||||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
partJSON, _ = sjson.SetRawBytes(partJSON, "inlineData", inlineDataJSON)
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reorder parts for 'model' role to ensure thinking block is first
|
// Reorder parts for 'model' role to ensure thinking block is first
|
||||||
if role == "model" {
|
if role == "model" {
|
||||||
partsResult := gjson.Get(clientContentJSON, "parts")
|
partsResult := gjson.GetBytes(clientContentJSON, "parts")
|
||||||
if partsResult.IsArray() {
|
if partsResult.IsArray() {
|
||||||
parts := partsResult.Array()
|
parts := partsResult.Array()
|
||||||
var thinkingParts []gjson.Result
|
var thinkingParts []gjson.Result
|
||||||
@@ -336,7 +354,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
for _, p := range otherParts {
|
for _, p := range otherParts {
|
||||||
newParts = append(newParts, p.Value())
|
newParts = append(newParts, p.Value())
|
||||||
}
|
}
|
||||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
|
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -344,33 +362,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
// Skip messages with empty parts array to avoid Gemini API error:
|
// Skip messages with empty parts array to avoid Gemini API error:
|
||||||
// "required oneof field 'data' must have one initialized field"
|
// "required oneof field 'data' must have one initialized field"
|
||||||
partsCheck := gjson.Get(clientContentJSON, "parts")
|
partsCheck := gjson.GetBytes(clientContentJSON, "parts")
|
||||||
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
|
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
|
||||||
hasContents = true
|
hasContents = true
|
||||||
} else if contentsResult.Type == gjson.String {
|
} else if contentsResult.Type == gjson.String {
|
||||||
prompt := contentsResult.String()
|
prompt := contentsResult.String()
|
||||||
partJSON := `{}`
|
partJSON := []byte(`{}`)
|
||||||
if prompt != "" {
|
if prompt != "" {
|
||||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
|
||||||
}
|
}
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
|
||||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
|
||||||
hasContents = true
|
hasContents = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools
|
// tools
|
||||||
toolsJSON := ""
|
var toolsJSON []byte
|
||||||
toolDeclCount := 0
|
toolDeclCount := 0
|
||||||
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
|
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
|
||||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||||
if toolsResult.IsArray() {
|
if toolsResult.IsArray() {
|
||||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
toolsJSON = []byte(`[{"functionDeclarations":[]}]`)
|
||||||
toolsResults := toolsResult.Array()
|
toolsResults := toolsResult.Array()
|
||||||
for i := 0; i < len(toolsResults); i++ {
|
for i := 0; i < len(toolsResults); i++ {
|
||||||
toolResult := toolsResults[i]
|
toolResult := toolsResults[i]
|
||||||
@@ -378,23 +396,24 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||||
// Sanitize the input schema for Antigravity API compatibility
|
// Sanitize the input schema for Antigravity API compatibility
|
||||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
|
||||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
|
||||||
for toolKey := range gjson.Parse(tool).Map() {
|
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||||
|
for toolKey := range gjson.ParseBytes(tool).Map() {
|
||||||
if util.InArray(allowedToolKeys, toolKey) {
|
if util.InArray(allowedToolKeys, toolKey) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tool, _ = sjson.Delete(tool, toolKey)
|
tool, _ = sjson.DeleteBytes(tool, toolKey)
|
||||||
}
|
}
|
||||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||||
toolDeclCount++
|
toolDeclCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build output Gemini CLI request JSON
|
// Build output Gemini CLI request JSON
|
||||||
out := `{"model":"","request":{"contents":[]}}`
|
out := []byte(`{"model":"","request":{"contents":[]}}`)
|
||||||
out, _ = sjson.Set(out, "model", modelName)
|
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||||
|
|
||||||
// Inject interleaved thinking hint when both tools and thinking are active
|
// Inject interleaved thinking hint when both tools and thinking are active
|
||||||
hasTools := toolDeclCount > 0
|
hasTools := toolDeclCount > 0
|
||||||
@@ -408,27 +427,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
if hasSystemInstruction {
|
if hasSystemInstruction {
|
||||||
// Append hint as a new part to existing system instruction
|
// Append hint as a new part to existing system instruction
|
||||||
hintPart := `{"text":""}`
|
hintPart := []byte(`{"text":""}`)
|
||||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
|
||||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
|
||||||
} else {
|
} else {
|
||||||
// Create new system instruction with hint
|
// Create new system instruction with hint
|
||||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
|
||||||
hintPart := `{"text":""}`
|
hintPart := []byte(`{"text":""}`)
|
||||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
|
||||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
|
||||||
hasSystemInstruction = true
|
hasSystemInstruction = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasSystemInstruction {
|
if hasSystemInstruction {
|
||||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstructionJSON)
|
||||||
}
|
}
|
||||||
if hasContents {
|
if hasContents {
|
||||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
out, _ = sjson.SetRawBytes(out, "request.contents", contentsJSON)
|
||||||
}
|
}
|
||||||
if toolDeclCount > 0 {
|
if toolDeclCount > 0 {
|
||||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
out, _ = sjson.SetRawBytes(out, "request.tools", toolsJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
// tool_choice
|
// tool_choice
|
||||||
@@ -445,15 +464,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
switch toolChoiceType {
|
switch toolChoiceType {
|
||||||
case "auto":
|
case "auto":
|
||||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
|
||||||
case "none":
|
case "none":
|
||||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
|
||||||
case "any":
|
case "any":
|
||||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||||
case "tool":
|
case "tool":
|
||||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
|
||||||
if toolChoiceName != "" {
|
if toolChoiceName != "" {
|
||||||
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
|
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -464,8 +483,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
case "enabled":
|
case "enabled":
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
case "adaptive", "auto":
|
case "adaptive", "auto":
|
||||||
// For adaptive thinking:
|
// For adaptive thinking:
|
||||||
@@ -477,28 +496,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
}
|
}
|
||||||
if effort != "" {
|
if effort != "" {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||||
} else {
|
} else {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
}
|
}
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", v.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
outBytes := []byte(out)
|
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
|
||||||
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
|
|
||||||
|
|
||||||
return outBytes
|
return out
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
|||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-3-5-sonnet-20240620",
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "get_weather-call-123",
|
||||||
|
"name": "get_weather",
|
||||||
|
"input": {"location": "Paris"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
|||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check function response conversion
|
// Check function response conversion
|
||||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||||
if !funcResp.Exists() {
|
if !funcResp.Exists() {
|
||||||
t.Error("functionResponse should exist")
|
t.Error("functionResponse should exist")
|
||||||
}
|
}
|
||||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||||
}
|
}
|
||||||
|
if funcResp.Get("name").String() != "get_weather" {
|
||||||
|
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-haiku-4-5-20251001",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||||
|
"name": "Glob",
|
||||||
|
"input": {"pattern": "**/*.py"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||||
|
"name": "Bash",
|
||||||
|
"input": {"command": "ls"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||||
|
"content": "file1.py\nfile2.py"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
|
||||||
|
"content": "total 10"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||||
|
if !funcResp0.Exists() {
|
||||||
|
t.Fatal("first functionResponse should exist")
|
||||||
|
}
|
||||||
|
if got := funcResp0.Get("name").String(); got != "Glob" {
|
||||||
|
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
|
||||||
|
if !funcResp1.Exists() {
|
||||||
|
t.Fatal("second functionResponse should exist")
|
||||||
|
}
|
||||||
|
if got := funcResp1.Get("name").String(); got != "Bash" {
|
||||||
|
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-haiku-4-5-20251001",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "Read-1773420180464065165-1327",
|
||||||
|
"name": "Read",
|
||||||
|
"input": {"file_path": "/tmp/test.py"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "Read-1773420180464065165-1327",
|
||||||
|
"content": "file content here"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
if got := funcResp.Get("name").String(); got != "Read" {
|
||||||
|
t.Errorf("Expected name 'Read', got '%s'", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "get_weather-call-123",
|
||||||
|
"content": "22C sunny"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
if got := funcResp.Get("name").String(); got != "get_weather" {
|
||||||
|
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
|
||||||
|
"content": "result data"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||||
|
if !funcResp.Exists() {
|
||||||
|
t.Fatal("functionResponse should exist")
|
||||||
|
}
|
||||||
|
got := funcResp.Get("name").String()
|
||||||
|
if got == "" {
|
||||||
|
t.Error("functionResponse.name must not be empty")
|
||||||
|
}
|
||||||
|
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
|
||||||
|
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
|
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -43,6 +44,10 @@ type Params struct {
|
|||||||
|
|
||||||
// Signature caching support
|
// Signature caching support
|
||||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||||
|
|
||||||
|
// Reverse map: sanitized Gemini function name → original Claude tool name.
|
||||||
|
// Populated lazily on the first response chunk from the original request JSON.
|
||||||
|
ToolNameMap map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||||
@@ -63,13 +68,14 @@ var toolUseIDCounter uint64
|
|||||||
// - param: A pointer to a parameter object for maintaining state between calls
|
// - param: A pointer to a parameter object for maintaining state between calls
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
|
||||||
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &Params{
|
*param = &Params{
|
||||||
HasFirstResponse: false,
|
HasFirstResponse: false,
|
||||||
ResponseType: 0,
|
ResponseType: 0,
|
||||||
ResponseIndex: 0,
|
ResponseIndex: 0,
|
||||||
|
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||||
@@ -77,44 +83,44 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
params := (*param).(*Params)
|
params := (*param).(*Params)
|
||||||
|
|
||||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||||
output := ""
|
output := make([]byte, 0, 256)
|
||||||
// Only send final events if we have actually output content
|
// Only send final events if we have actually output content
|
||||||
if params.HasContent {
|
if params.HasContent {
|
||||||
appendFinalEvents(params, &output, true)
|
appendFinalEvents(params, &output, true)
|
||||||
return []string{
|
output = translatorcommon.AppendSSEEventString(output, "message_stop", `{"type":"message_stop"}`, 3)
|
||||||
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
|
return [][]byte{output}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return []string{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
output := ""
|
output := make([]byte, 0, 1024)
|
||||||
|
appendEvent := func(event, payload string) {
|
||||||
|
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the streaming session with a message_start event
|
// Initialize the streaming session with a message_start event
|
||||||
// This is only sent for the very first response chunk to establish the streaming session
|
// This is only sent for the very first response chunk to establish the streaming session
|
||||||
if !params.HasFirstResponse {
|
if !params.HasFirstResponse {
|
||||||
output = "event: message_start\n"
|
|
||||||
|
|
||||||
// Create the initial message structure with default values according to Claude Code API specification
|
// Create the initial message structure with default values according to Claude Code API specification
|
||||||
// This follows the Claude Code API specification for streaming message initialization
|
// This follows the Claude Code API specification for streaming message initialization
|
||||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
messageStartTemplate := []byte(`{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`)
|
||||||
|
|
||||||
// Use cpaUsageMetadata within the message_start event for Claude.
|
// Use cpaUsageMetadata within the message_start event for Claude.
|
||||||
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||||
}
|
}
|
||||||
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||||
}
|
}
|
||||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
appendEvent("message_start", string(messageStartTemplate))
|
||||||
|
|
||||||
params.HasFirstResponse = true
|
params.HasFirstResponse = true
|
||||||
}
|
}
|
||||||
@@ -144,15 +150,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
params.CurrentThinkingText.Reset()
|
params.CurrentThinkingText.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
output = output + "event: content_block_delta\n"
|
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex)), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
appendEvent("content_block_delta", string(data))
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||||
output = output + "event: content_block_delta\n"
|
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
appendEvent("content_block_delta", string(data))
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
} else {
|
} else {
|
||||||
// Transition from another state to thinking
|
// Transition from another state to thinking
|
||||||
@@ -163,19 +167,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||||
// output = output + "\n\n\n"
|
// output = output + "\n\n\n"
|
||||||
}
|
}
|
||||||
output = output + "event: content_block_stop\n"
|
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
params.ResponseIndex++
|
params.ResponseIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new thinking content block
|
// Start a new thinking content block
|
||||||
output = output + "event: content_block_start\n"
|
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex))
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
|
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
|
||||||
output = output + "\n\n\n"
|
appendEvent("content_block_delta", string(data))
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
params.ResponseType = 2 // Set state to thinking
|
params.ResponseType = 2 // Set state to thinking
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
// Start accumulating thinking text for signature caching
|
// Start accumulating thinking text for signature caching
|
||||||
@@ -188,9 +187,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
// Process regular text content (user-visible output)
|
// Process regular text content (user-visible output)
|
||||||
// Continue existing text block if already in content state
|
// Continue existing text block if already in content state
|
||||||
if params.ResponseType == 1 {
|
if params.ResponseType == 1 {
|
||||||
output = output + "event: content_block_delta\n"
|
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
appendEvent("content_block_delta", string(data))
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
} else {
|
} else {
|
||||||
// Transition from another state to text content
|
// Transition from another state to text content
|
||||||
@@ -201,19 +199,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||||
// output = output + "\n\n\n"
|
// output = output + "\n\n\n"
|
||||||
}
|
}
|
||||||
output = output + "event: content_block_stop\n"
|
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
params.ResponseIndex++
|
params.ResponseIndex++
|
||||||
}
|
}
|
||||||
if partTextResult.String() != "" {
|
if partTextResult.String() != "" {
|
||||||
// Start a new text content block
|
// Start a new text content block
|
||||||
output = output + "event: content_block_start\n"
|
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex))
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
|
||||||
output = output + "\n\n\n"
|
appendEvent("content_block_delta", string(data))
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
params.ResponseType = 1 // Set state to content
|
params.ResponseType = 1 // Set state to content
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
}
|
}
|
||||||
@@ -224,14 +217,12 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
// Handle function/tool calls from the AI model
|
// Handle function/tool calls from the AI model
|
||||||
// This processes tool usage requests and formats them for Claude Code API compatibility
|
// This processes tool usage requests and formats them for Claude Code API compatibility
|
||||||
params.HasToolUse = true
|
params.HasToolUse = true
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
|
||||||
|
|
||||||
// Handle state transitions when switching to function calls
|
// Handle state transitions when switching to function calls
|
||||||
// Close any existing function call block first
|
// Close any existing function call block first
|
||||||
if params.ResponseType == 3 {
|
if params.ResponseType == 3 {
|
||||||
output = output + "event: content_block_stop\n"
|
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
params.ResponseIndex++
|
params.ResponseIndex++
|
||||||
params.ResponseType = 0
|
params.ResponseType = 0
|
||||||
}
|
}
|
||||||
@@ -245,26 +236,21 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Close any other existing content block
|
// Close any other existing content block
|
||||||
if params.ResponseType != 0 {
|
if params.ResponseType != 0 {
|
||||||
output = output + "event: content_block_stop\n"
|
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
params.ResponseIndex++
|
params.ResponseIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new tool use content block
|
// Start a new tool use content block
|
||||||
// This creates the structure for a function call in Claude Code format
|
// This creates the structure for a function call in Claude Code format
|
||||||
output = output + "event: content_block_start\n"
|
|
||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex))
|
||||||
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
appendEvent("content_block_start", string(data))
|
||||||
|
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
output = output + "event: content_block_delta\n"
|
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
|
||||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
appendEvent("content_block_delta", string(data))
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
}
|
}
|
||||||
params.ResponseType = 3
|
params.ResponseType = 3
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
@@ -296,10 +282,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
appendFinalEvents(params, &output, false)
|
appendFinalEvents(params, &output, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{output}
|
return [][]byte{output}
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendFinalEvents(params *Params, output *string, force bool) {
|
func appendFinalEvents(params *Params, output *[]byte, force bool) {
|
||||||
if params.HasSentFinalEvents {
|
if params.HasSentFinalEvents {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -314,9 +300,7 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if params.ResponseType != 0 {
|
if params.ResponseType != 0 {
|
||||||
*output = *output + "event: content_block_stop\n"
|
*output = translatorcommon.AppendSSEEventString(*output, "content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex), 3)
|
||||||
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
|
||||||
*output = *output + "\n\n\n"
|
|
||||||
params.ResponseType = 0
|
params.ResponseType = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -329,18 +313,16 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
*output = *output + "event: message_delta\n"
|
delta := []byte(fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens))
|
||||||
*output = *output + "data: "
|
|
||||||
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
|
||||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||||
if params.CachedTokenCount > 0 {
|
if params.CachedTokenCount > 0 {
|
||||||
var err error
|
var err error
|
||||||
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
delta, err = sjson.SetBytes(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*output = *output + delta + "\n\n\n"
|
*output = translatorcommon.AppendSSEEventString(*output, "message_delta", string(delta), 3)
|
||||||
|
|
||||||
params.HasSentFinalEvents = true
|
params.HasSentFinalEvents = true
|
||||||
}
|
}
|
||||||
@@ -369,9 +351,9 @@ func resolveStopReason(params *Params) string {
|
|||||||
// - param: A pointer to a parameter object for the conversion.
|
// - param: A pointer to a parameter object for the conversion.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: A Claude-compatible JSON response.
|
// - []byte: A Claude-compatible JSON response.
|
||||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||||
_ = originalRequestRawJSON
|
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||||
|
|
||||||
root := gjson.ParseBytes(rawJSON)
|
root := gjson.ParseBytes(rawJSON)
|
||||||
@@ -388,15 +370,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
responseJSON := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
responseJSON, _ = sjson.SetBytes(responseJSON, "id", root.Get("response.responseId").String())
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
responseJSON, _ = sjson.SetBytes(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.input_tokens", promptTokens)
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.output_tokens", outputTokens)
|
||||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||||
if cachedTokens > 0 {
|
if cachedTokens > 0 {
|
||||||
var err error
|
var err error
|
||||||
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
responseJSON, err = sjson.SetBytes(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||||
}
|
}
|
||||||
@@ -407,7 +389,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
if contentArrayInitialized {
|
if contentArrayInitialized {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content", []byte("[]"))
|
||||||
contentArrayInitialized = true
|
contentArrayInitialized = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,9 +405,9 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
ensureContentArray()
|
ensureContentArray()
|
||||||
block := `{"type":"text","text":""}`
|
block := []byte(`{"type":"text","text":""}`)
|
||||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
|
||||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
|
||||||
textBuilder.Reset()
|
textBuilder.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -434,12 +416,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
ensureContentArray()
|
ensureContentArray()
|
||||||
block := `{"type":"thinking","thinking":""}`
|
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||||
if thinkingSignature != "" {
|
if thinkingSignature != "" {
|
||||||
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
block, _ = sjson.SetBytes(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||||
}
|
}
|
||||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
|
||||||
thinkingBuilder.Reset()
|
thinkingBuilder.Reset()
|
||||||
thinkingSignature = ""
|
thinkingSignature = ""
|
||||||
}
|
}
|
||||||
@@ -473,18 +455,18 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
flushText()
|
flushText()
|
||||||
hasToolCall = true
|
hasToolCall = true
|
||||||
|
|
||||||
name := functionCall.Get("name").String()
|
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
|
||||||
toolIDCounter++
|
toolIDCounter++
|
||||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||||
|
|
||||||
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
|
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
|
||||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(args.Raw))
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureContentArray()
|
ensureContentArray()
|
||||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", toolBlock)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -508,17 +490,17 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
responseJSON, _ = sjson.SetBytes(responseJSON, "stop_reason", stopReason)
|
||||||
|
|
||||||
if promptTokens == 0 && outputTokens == 0 {
|
if promptTokens == 0 && outputTokens == 0 {
|
||||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||||
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
responseJSON, _ = sjson.DeleteBytes(responseJSON, "usage")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return responseJSON
|
return responseJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
||||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,10 +34,10 @@ import (
|
|||||||
// - []byte: The transformed request data in Gemini API format
|
// - []byte: The transformed request data in Gemini API format
|
||||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
rawJSON := inputRawJSON
|
rawJSON := inputRawJSON
|
||||||
template := ""
|
template := `{"project":"","request":{},"model":""}`
|
||||||
template = `{"project":"","request":{},"model":""}`
|
templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
|
||||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
|
||||||
template, _ = sjson.Set(template, "model", modelName)
|
template = string(templateBytes)
|
||||||
template, _ = sjson.Delete(template, "request.model")
|
template, _ = sjson.Delete(template, "request.model")
|
||||||
|
|
||||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||||
@@ -47,7 +47,8 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||||
if systemInstructionResult.Exists() {
|
if systemInstructionResult.Exists() {
|
||||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
|
||||||
|
template = string(templateBytes)
|
||||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||||
}
|
}
|
||||||
rawJSON = []byte(template)
|
rawJSON = []byte(template)
|
||||||
@@ -138,30 +139,47 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
// FunctionCallGroup represents a group of function calls and their responses
|
// FunctionCallGroup represents a group of function calls and their responses
|
||||||
type FunctionCallGroup struct {
|
type FunctionCallGroup struct {
|
||||||
ResponsesNeeded int
|
ResponsesNeeded int
|
||||||
|
CallNames []string // ordered function call names for backfilling empty response names
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
|
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
|
||||||
// Falls back to a minimal "functionResponse" object when parsing fails.
|
// Falls back to a minimal "functionResponse" object when parsing fails.
|
||||||
func parseFunctionResponseRaw(response gjson.Result) string {
|
// fallbackName is used when the response's own name is empty.
|
||||||
|
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
|
||||||
if response.IsObject() && gjson.Valid(response.Raw) {
|
if response.IsObject() && gjson.Valid(response.Raw) {
|
||||||
return response.Raw
|
raw := response.Raw
|
||||||
|
name := response.Get("functionResponse.name").String()
|
||||||
|
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
||||||
|
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
|
||||||
|
raw = string(updated)
|
||||||
|
}
|
||||||
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("parse function response failed, using fallback")
|
log.Debugf("parse function response failed, using fallback")
|
||||||
funcResp := response.Get("functionResponse")
|
funcResp := response.Get("functionResponse")
|
||||||
if funcResp.Exists() {
|
if funcResp.Exists() {
|
||||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||||
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
|
name := funcResp.Get("name").String()
|
||||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
if strings.TrimSpace(name) == "" {
|
||||||
if id := funcResp.Get("id").String(); id != "" {
|
name = fallbackName
|
||||||
fr, _ = sjson.Set(fr, "functionResponse.id", id)
|
|
||||||
}
|
}
|
||||||
return fr
|
fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
|
||||||
|
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||||
|
if id := funcResp.Get("id").String(); id != "" {
|
||||||
|
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
|
||||||
|
}
|
||||||
|
return string(fr)
|
||||||
}
|
}
|
||||||
|
|
||||||
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
|
useName := fallbackName
|
||||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
|
if useName == "" {
|
||||||
return fr
|
useName = "unknown"
|
||||||
|
}
|
||||||
|
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
||||||
|
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
|
||||||
|
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
|
||||||
|
return string(fr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||||
@@ -188,7 +206,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize data structures for processing and grouping
|
// Initialize data structures for processing and grouping
|
||||||
contentsWrapper := `{"contents":[]}`
|
contentsWrapper := []byte(`{"contents":[]}`)
|
||||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||||
|
|
||||||
@@ -211,30 +229,26 @@ func fixCLIToolResponse(input string) (string, error) {
|
|||||||
if len(responsePartsInThisContent) > 0 {
|
if len(responsePartsInThisContent) > 0 {
|
||||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||||
|
|
||||||
// Check if any pending groups can be satisfied
|
// Check if pending groups can be satisfied (FIFO: oldest group first)
|
||||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
|
||||||
group := pendingGroups[i]
|
group := pendingGroups[0]
|
||||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
pendingGroups = pendingGroups[1:]
|
||||||
// Take the needed responses for this group
|
|
||||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
|
||||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
|
||||||
|
|
||||||
// Create merged function response content
|
// Take the needed responses for this group
|
||||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||||
for _, response := range groupResponses {
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||||
partRaw := parseFunctionResponseRaw(response)
|
|
||||||
if partRaw != "" {
|
// Create merged function response content
|
||||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||||
}
|
for ri, response := range groupResponses {
|
||||||
|
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||||
|
if partRaw != "" {
|
||||||
|
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||||
}
|
|
||||||
|
|
||||||
// Remove this group as it's been satisfied
|
|
||||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,25 +257,26 @@ func fixCLIToolResponse(input string) (string, error) {
|
|||||||
|
|
||||||
// If this is a model with function calls, create a new group
|
// If this is a model with function calls, create a new group
|
||||||
if role == "model" {
|
if role == "model" {
|
||||||
functionCallsCount := 0
|
var callNames []string
|
||||||
parts.ForEach(func(_, part gjson.Result) bool {
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
if part.Get("functionCall").Exists() {
|
if part.Get("functionCall").Exists() {
|
||||||
functionCallsCount++
|
callNames = append(callNames, part.Get("functionCall.name").String())
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
if functionCallsCount > 0 {
|
if len(callNames) > 0 {
|
||||||
// Add the model content
|
// Add the model content
|
||||||
if !value.IsObject() {
|
if !value.IsObject() {
|
||||||
log.Warnf("failed to parse model content")
|
log.Warnf("failed to parse model content")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||||
|
|
||||||
// Create a new group for tracking responses
|
// Create a new group for tracking responses
|
||||||
group := &FunctionCallGroup{
|
group := &FunctionCallGroup{
|
||||||
ResponsesNeeded: functionCallsCount,
|
ResponsesNeeded: len(callNames),
|
||||||
|
CallNames: callNames,
|
||||||
}
|
}
|
||||||
pendingGroups = append(pendingGroups, group)
|
pendingGroups = append(pendingGroups, group)
|
||||||
} else {
|
} else {
|
||||||
@@ -270,7 +285,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
|||||||
log.Warnf("failed to parse content")
|
log.Warnf("failed to parse content")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Non-model content (user, etc.)
|
// Non-model content (user, etc.)
|
||||||
@@ -278,7 +293,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
|||||||
log.Warnf("failed to parse content")
|
log.Warnf("failed to parse content")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -290,23 +305,22 @@ func fixCLIToolResponse(input string) (string, error) {
|
|||||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||||
|
|
||||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
||||||
for _, response := range groupResponses {
|
for ri, response := range groupResponses {
|
||||||
partRaw := parseFunctionResponseRaw(response)
|
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
||||||
if partRaw != "" {
|
if partRaw != "" {
|
||||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
||||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the original JSON with the new contents
|
// Update the original JSON with the new contents
|
||||||
result := input
|
result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
|
||||||
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
|
|
||||||
|
|
||||||
return result, nil
|
return string(result), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
|
|||||||
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
|
||||||
|
// When the Amp client sends functionResponse with an empty name,
|
||||||
|
// fixCLIToolResponse should backfill it from the corresponding functionCall.
|
||||||
|
input := `{
|
||||||
|
"model": "gemini-3-pro-preview",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
name := funcContent.Get("parts.0.functionResponse.name").String()
|
||||||
|
if name != "Bash" {
|
||||||
|
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
|
||||||
|
// Parallel function calls: both responses have empty names.
|
||||||
|
input := `{
|
||||||
|
"model": "gemini-3-pro-preview",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
|
||||||
|
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
|
||||||
|
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := funcContent.Get("parts").Array()
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("Expected 2 function response parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
name0 := parts[0].Get("functionResponse.name").String()
|
||||||
|
name1 := parts[1].Get("functionResponse.name").String()
|
||||||
|
if name0 != "Read" {
|
||||||
|
t.Errorf("Expected first response name 'Read', got '%s'", name0)
|
||||||
|
}
|
||||||
|
if name1 != "Grep" {
|
||||||
|
t.Errorf("Expected second response name 'Grep', got '%s'", name1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
|
||||||
|
// When functionResponse already has a valid name, it should be preserved.
|
||||||
|
input := `{
|
||||||
|
"model": "gemini-3-pro-preview",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"functionCall": {"name": "Bash", "args": {}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
name := funcContent.Get("parts.0.functionResponse.name").String()
|
||||||
|
if name != "Bash" {
|
||||||
|
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
|
||||||
|
// If there are more function responses than calls, unmatched extras are discarded by grouping.
|
||||||
|
input := `{
|
||||||
|
"model": "gemini-3-pro-preview",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"functionCall": {"name": "Bash", "args": {}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
|
||||||
|
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContent gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContent = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !funcContent.Exists() {
|
||||||
|
t.Fatal("function role content should exist in output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// First response should be backfilled from the call
|
||||||
|
name0 := funcContent.Get("parts.0.functionResponse.name").String()
|
||||||
|
if name0 != "Bash" {
|
||||||
|
t.Errorf("Expected first response name 'Bash', got '%s'", name0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
|
||||||
|
// Two sequential function call groups should be matched FIFO.
|
||||||
|
input := `{
|
||||||
|
"model": "gemini-3-pro-preview",
|
||||||
|
"request": {
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"functionCall": {"name": "Read", "args": {}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{"functionResponse": {"name": "", "response": {"result": "file content"}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "model",
|
||||||
|
"parts": [
|
||||||
|
{"functionCall": {"name": "Grep", "args": {}}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{"functionResponse": {"name": "", "response": {"result": "match"}}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := fixCLIToolResponse(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fixCLIToolResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := gjson.Get(result, "request.contents").Array()
|
||||||
|
var funcContents []gjson.Result
|
||||||
|
for _, c := range contents {
|
||||||
|
if c.Get("role").String() == "function" {
|
||||||
|
funcContents = append(funcContents, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(funcContents) != 2 {
|
||||||
|
t.Fatalf("Expected 2 function contents, got %d", len(funcContents))
|
||||||
|
}
|
||||||
|
|
||||||
|
name0 := funcContents[0].Get("parts.0.functionResponse.name").String()
|
||||||
|
name1 := funcContents[1].Get("parts.0.functionResponse.name").String()
|
||||||
|
if name0 != "Read" {
|
||||||
|
t.Errorf("Expected first group name 'Read', got '%s'", name0)
|
||||||
|
}
|
||||||
|
if name1 != "Grep" {
|
||||||
|
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ package gemini
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
|
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -29,8 +29,8 @@ import (
|
|||||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - []string: The transformed request data in Gemini API format
|
// - [][]byte: The transformed response data in Gemini API format.
|
||||||
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
|
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
|
||||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||||
}
|
}
|
||||||
@@ -44,22 +44,22 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
|||||||
chunk = restoreUsageMetadata(chunk)
|
chunk = restoreUsageMetadata(chunk)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
chunkTemplate := "[]"
|
chunkTemplate := []byte("[]")
|
||||||
responseResult := gjson.ParseBytes(chunk)
|
responseResult := gjson.ParseBytes(chunk)
|
||||||
if responseResult.IsArray() {
|
if responseResult.IsArray() {
|
||||||
responseResultItems := responseResult.Array()
|
responseResultItems := responseResult.Array()
|
||||||
for i := 0; i < len(responseResultItems); i++ {
|
for i := 0; i < len(responseResultItems); i++ {
|
||||||
responseResultItem := responseResultItems[i]
|
responseResultItem := responseResultItems[i]
|
||||||
if responseResultItem.Get("response").Exists() {
|
if responseResultItem.Get("response").Exists() {
|
||||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
chunk = []byte(chunkTemplate)
|
chunk = chunkTemplate
|
||||||
}
|
}
|
||||||
return []string{string(chunk)}
|
return [][]byte{chunk}
|
||||||
}
|
}
|
||||||
return []string{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
|
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
|
||||||
@@ -73,18 +73,18 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
|||||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: A Gemini-compatible JSON response containing the response data
|
// - []byte: A Gemini-compatible JSON response containing the response data.
|
||||||
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
|
||||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||||
if responseResult.Exists() {
|
if responseResult.Exists() {
|
||||||
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
|
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
|
||||||
return string(chunk)
|
return chunk
|
||||||
}
|
}
|
||||||
return string(rawJSON)
|
return rawJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
func GeminiTokenCount(ctx context.Context, count int64) []byte {
|
||||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
return translatorcommon.GeminiTokenCountJSON(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.
|
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
|
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
|
||||||
if result != tt.expected {
|
if string(result) != tt.expected {
|
||||||
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected)
|
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", string(result), tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -87,8 +87,8 @@ func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
|
|||||||
if len(results) != 1 {
|
if len(results) != 1 {
|
||||||
t.Fatalf("expected 1 result, got %d", len(results))
|
t.Fatalf("expected 1 result, got %d", len(results))
|
||||||
}
|
}
|
||||||
if results[0] != tt.expected {
|
if string(results[0]) != tt.expected {
|
||||||
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected)
|
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", string(results[0]), tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fid := tc.Get("id").String()
|
fid := tc.Get("id").String()
|
||||||
fname := tc.Get("function.name").String()
|
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
|
||||||
fargs := tc.Get("function.arguments").String()
|
fargs := tc.Get("function.arguments").String()
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||||
@@ -309,7 +309,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
for _, fid := range fIDs {
|
for _, fid := range fIDs {
|
||||||
if name, ok := tcID2Name[fid]; ok {
|
if name, ok := tcID2Name[fid]; ok {
|
||||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
|
||||||
resp := toolResponses[fid]
|
resp := toolResponses[fid]
|
||||||
if resp == "" {
|
if resp == "" {
|
||||||
resp = "{}"
|
resp = "{}"
|
||||||
@@ -354,33 +354,39 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if errRename != nil {
|
if errRename != nil {
|
||||||
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
|
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
|
||||||
var errSet error
|
var errSet error
|
||||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
fnRaw = string(fnRawBytes)
|
||||||
|
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
fnRaw = string(fnRawBytes)
|
||||||
} else {
|
} else {
|
||||||
fnRaw = renamed
|
fnRaw = renamed
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
var errSet error
|
var errSet error
|
||||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
fnRaw = string(fnRawBytes)
|
||||||
|
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
fnRaw = string(fnRawBytes)
|
||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRawBytes := []byte(fnRaw)
|
||||||
|
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
|
||||||
|
fnRaw, _ = sjson.Delete(string(fnRawBytes), "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||||
@@ -26,6 +27,7 @@ type convertCliResponseToOpenAIChatParams struct {
|
|||||||
FunctionIndex int
|
FunctionIndex int
|
||||||
SawToolCall bool // Tracks if any tool call was seen in the entire stream
|
SawToolCall bool // Tracks if any tool call was seen in the entire stream
|
||||||
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
|
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
|
||||||
|
SanitizedNameMap map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
@@ -44,25 +46,29 @@ var functionCallIDCounter uint64
|
|||||||
// - param: A pointer to a parameter object for maintaining state between calls
|
// - param: A pointer to a parameter object for maintaining state between calls
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
// - [][]byte: A slice of OpenAI-compatible JSON responses
|
||||||
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &convertCliResponseToOpenAIChatParams{
|
*param = &convertCliResponseToOpenAIChatParams{
|
||||||
UnixTimestamp: 0,
|
UnixTimestamp: 0,
|
||||||
FunctionIndex: 0,
|
FunctionIndex: 0,
|
||||||
|
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
|
||||||
|
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
|
||||||
|
}
|
||||||
|
|
||||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||||
return []string{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the OpenAI SSE template.
|
// Initialize the OpenAI SSE template.
|
||||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
|
||||||
|
|
||||||
// Extract and set the model version.
|
// Extract and set the model version.
|
||||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the creation timestamp.
|
// Extract and set the creation timestamp.
|
||||||
@@ -71,14 +77,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the response ID.
|
// Extract and set the response ID.
|
||||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
||||||
@@ -90,21 +96,21 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
// Include cached token count if present (indicates prompt caching is working)
|
// Include cached token count if present (indicates prompt caching is working)
|
||||||
if cachedTokenCount > 0 {
|
if cachedTokenCount > 0 {
|
||||||
var err error
|
var err error
|
||||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||||
}
|
}
|
||||||
@@ -141,33 +147,33 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||||
if partResult.Get("thought").Bool() {
|
if partResult.Get("thought").Bool() {
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
|
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
|
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||||
} else if functionCallResult.Exists() {
|
} else if functionCallResult.Exists() {
|
||||||
// Handle function call content.
|
// Handle function call content.
|
||||||
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
|
||||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||||
functionCallIndex = len(toolCallsResult.Array())
|
functionCallIndex = len(toolCallsResult.Array())
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
|
||||||
}
|
}
|
||||||
|
|
||||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||||
} else if inlineDataResult.Exists() {
|
} else if inlineDataResult.Exists() {
|
||||||
data := inlineDataResult.Get("data").String()
|
data := inlineDataResult.Get("data").String()
|
||||||
if data == "" {
|
if data == "" {
|
||||||
@@ -181,16 +187,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
mimeType = "image/png"
|
mimeType = "image/png"
|
||||||
}
|
}
|
||||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
|
||||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
|
||||||
}
|
}
|
||||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
|
||||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
|
||||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -212,11 +218,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
} else {
|
} else {
|
||||||
finishReason = "stop"
|
finishReason = "stop"
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{template}
|
return [][]byte{template}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
|
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
|
||||||
@@ -231,11 +237,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
// - param: A pointer to a parameter object for the conversion
|
// - param: A pointer to a parameter object for the conversion
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
|
||||||
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||||
if responseResult.Exists() {
|
if responseResult.Exists() {
|
||||||
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
|
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
|
||||||
}
|
}
|
||||||
return ""
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
|||||||
if len(result1) != 1 {
|
if len(result1) != 1 {
|
||||||
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
||||||
}
|
}
|
||||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||||
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
||||||
}
|
}
|
||||||
@@ -33,13 +33,13 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
|||||||
if len(result2) != 1 {
|
if len(result2) != 1 {
|
||||||
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
||||||
}
|
}
|
||||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||||
if fr2 != "tool_calls" {
|
if fr2 != "tool_calls" {
|
||||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify native_finish_reason is lowercase upstream value
|
// Verify native_finish_reason is lowercase upstream value
|
||||||
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
|
nfr2 := gjson.GetBytes(result2[0], "choices.0.native_finish_reason").String()
|
||||||
if nfr2 != "stop" {
|
if nfr2 != "stop" {
|
||||||
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
||||||
}
|
}
|
||||||
@@ -58,7 +58,7 @@ func TestFinishReasonStopForNormalText(t *testing.T) {
|
|||||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
// Verify finish_reason is "stop" (no tool calls were made)
|
// Verify finish_reason is "stop" (no tool calls were made)
|
||||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||||
if fr != "stop" {
|
if fr != "stop" {
|
||||||
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
||||||
}
|
}
|
||||||
@@ -77,7 +77,7 @@ func TestFinishReasonMaxTokens(t *testing.T) {
|
|||||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
// Verify finish_reason is "max_tokens"
|
// Verify finish_reason is "max_tokens"
|
||||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||||
if fr != "max_tokens" {
|
if fr != "max_tokens" {
|
||||||
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
|
|||||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
||||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
|
||||||
if fr != "tool_calls" {
|
if fr != "tool_calls" {
|
||||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
||||||
}
|
}
|
||||||
@@ -111,7 +111,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
|||||||
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
// Verify no finish_reason on intermediate chunk
|
// Verify no finish_reason on intermediate chunk
|
||||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
|
||||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
||||||
}
|
}
|
||||||
@@ -121,7 +121,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
|||||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
// Verify no finish_reason on intermediate chunk
|
// Verify no finish_reason on intermediate chunk
|
||||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
|
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason")
|
||||||
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
||||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||||
if responseResult.Exists() {
|
if responseResult.Exists() {
|
||||||
rawJSON = []byte(responseResult.Raw)
|
rawJSON = []byte(responseResult.Raw)
|
||||||
@@ -15,7 +15,7 @@ func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName
|
|||||||
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||||
if responseResult.Exists() {
|
if responseResult.Exists() {
|
||||||
rawJSON = []byte(responseResult.Raw)
|
rawJSON = []byte(responseResult.Raw)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||||
"github.com/tidwall/sjson"
|
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
|
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
|
||||||
@@ -23,15 +23,13 @@ import (
|
|||||||
// - param: A pointer to a parameter object for maintaining state between calls
|
// - param: A pointer to a parameter object for maintaining state between calls
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
|
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
|
||||||
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||||
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||||
// Wrap each converted response in a "response" object to match Gemini CLI API structure
|
// Wrap each converted response in a "response" object to match Gemini CLI API structure
|
||||||
newOutputs := make([]string, 0)
|
newOutputs := make([][]byte, 0, len(outputs))
|
||||||
for i := 0; i < len(outputs); i++ {
|
for i := 0; i < len(outputs); i++ {
|
||||||
json := `{"response": {}}`
|
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
|
||||||
output, _ := sjson.SetRaw(json, "response", outputs[i])
|
|
||||||
newOutputs = append(newOutputs, output)
|
|
||||||
}
|
}
|
||||||
return newOutputs
|
return newOutputs
|
||||||
}
|
}
|
||||||
@@ -47,15 +45,13 @@ func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, ori
|
|||||||
// - param: A pointer to a parameter object for the conversion
|
// - param: A pointer to a parameter object for the conversion
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: A Gemini-compatible JSON response wrapped in a response object
|
// - []byte: A Gemini-compatible JSON response wrapped in a response object
|
||||||
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||||
strJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||||
// Wrap the converted response in a "response" object to match Gemini CLI API structure
|
// Wrap the converted response in a "response" object to match Gemini CLI API structure
|
||||||
json := `{"response": {}}`
|
return translatorcommon.WrapGeminiCLIResponse(out)
|
||||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
|
||||||
return strJSON
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
|
||||||
return GeminiTokenCount(ctx, count)
|
return GeminiTokenCount(ctx, count)
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user