mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-10 08:17:58 +00:00
Compare commits
219 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
938af75954 | ||
|
|
1dba2d0f81 | ||
|
|
730809d8ea | ||
|
|
5e81b65f2f | ||
|
|
c42480a574 | ||
|
|
55c146a0e7 | ||
|
|
ad8e3964ff | ||
|
|
e9dc576409 | ||
|
|
941334da79 | ||
|
|
d54f816363 | ||
|
|
f43d25def1 | ||
|
|
a279192881 | ||
|
|
6a43d7285c | ||
|
|
578c312660 | ||
|
|
6bb9bf3132 | ||
|
|
343a2fc2f7 | ||
|
|
12b967118b | ||
|
|
70efd4e016 | ||
|
|
f5aa68ecda | ||
|
|
9a5f142c33 | ||
|
|
d390b95b76 | ||
|
|
d1f6224b70 | ||
|
|
fcc59d606d | ||
|
|
91e7591955 | ||
|
|
4607356333 | ||
|
|
9a9ed99072 | ||
|
|
5ae38584b8 | ||
|
|
c8b7e2b8d6 | ||
|
|
cad45ffa33 | ||
|
|
6a27bceec0 | ||
|
|
163d68318f | ||
|
|
0ea768011b | ||
|
|
8b9dbe10f0 | ||
|
|
341b4beea1 | ||
|
|
bea13f9724 | ||
|
|
9f5bdfaa31 | ||
|
|
9eabdd09db | ||
|
|
c3f8dc362e | ||
|
|
b85120873b | ||
|
|
6f58518c69 | ||
|
|
000fcb15fa | ||
|
|
ea43361492 | ||
|
|
c1818f197b | ||
|
|
b0653cec7b | ||
|
|
22a1a24cf5 | ||
|
|
7223fee2de | ||
|
|
ada8e2905e | ||
|
|
4ba10531da | ||
|
|
3774b56e9f | ||
|
|
c2d4137fb9 | ||
|
|
2ee938acaf | ||
|
|
8d5e470e1f | ||
|
|
65e9e892a4 | ||
|
|
3882494878 | ||
|
|
088c1d07f4 | ||
|
|
8430b28cfa | ||
|
|
f3ab8f4bc5 | ||
|
|
0e4f189c2e | ||
|
|
98509f615c | ||
|
|
e7a66ae504 | ||
|
|
754b126944 | ||
|
|
ae37ccffbf | ||
|
|
42c062bb5b | ||
|
|
87bf0b73d5 | ||
|
|
f389667ec3 | ||
|
|
29dba0399b | ||
|
|
a824e7cd0b | ||
|
|
140faef7dc | ||
|
|
adb580b344 | ||
|
|
06405f2129 | ||
|
|
b849bf79d6 | ||
|
|
59af2c57b1 | ||
|
|
d1fd2c4ad4 | ||
|
|
b6c6379bfa | ||
|
|
8f0e66b72e | ||
|
|
f63cf6ff7a | ||
|
|
d2419ed49d | ||
|
|
516d22c695 | ||
|
|
73cda6e836 | ||
|
|
0805989ee5 | ||
|
|
9b5ce8c64f | ||
|
|
058793c73a | ||
|
|
75da02af55 | ||
|
|
ab9ebea592 | ||
|
|
7ee37ee4b9 | ||
|
|
837afffb31 | ||
|
|
03a1bac898 | ||
|
|
3171d524f0 | ||
|
|
3e78a8d500 | ||
|
|
fcba912cc4 | ||
|
|
7170eeea5f | ||
|
|
e3eb048c7a | ||
|
|
a59e92435b | ||
|
|
108895fc04 | ||
|
|
abc293c642 | ||
|
|
da3a498a28 | ||
|
|
bb44671845 | ||
|
|
09e480036a | ||
|
|
249f969110 | ||
|
|
4f8acec2d8 | ||
|
|
34339f61ee | ||
|
|
4045378cb4 | ||
|
|
2df35449fe | ||
|
|
c744179645 | ||
|
|
9720b03a6b | ||
|
|
f2c0f3d325 | ||
|
|
4f99bc54f1 | ||
|
|
913f4a9c5f | ||
|
|
25d1c18a3f | ||
|
|
d09dd4d0b2 | ||
|
|
474fb042da | ||
|
|
8435c3d7be | ||
|
|
e783d0a62e | ||
|
|
b05f575e9b | ||
|
|
f5e9f01811 | ||
|
|
ff7dbb5867 | ||
|
|
e34b2b4f1d | ||
|
|
15c2f274ea | ||
|
|
37249339ac | ||
|
|
c422d16beb | ||
|
|
66cd50f603 | ||
|
|
caa529c282 | ||
|
|
51a4379bf4 | ||
|
|
acf98ed10e | ||
|
|
d1c07a091e | ||
|
|
c1a8adf1ab | ||
|
|
08e078fc25 | ||
|
|
105a21548f | ||
|
|
1734aa1664 | ||
|
|
ca11b236a7 | ||
|
|
6fdff8227d | ||
|
|
330e12d3c2 | ||
|
|
bd09c0bf09 | ||
|
|
b468ca79c3 | ||
|
|
d2c7e4e96a | ||
|
|
1c7003ff68 | ||
|
|
1b44364e78 | ||
|
|
ec77f4a4f5 | ||
|
|
f611dd6e96 | ||
|
|
07b7c1a1e0 | ||
|
|
51fd58d74f | ||
|
|
faae9c2f7c | ||
|
|
bc3a6e4646 | ||
|
|
b09b03e35e | ||
|
|
16231947e7 | ||
|
|
39b9a38fbc | ||
|
|
bd855abec9 | ||
|
|
7c3c2e9f64 | ||
|
|
c10f8ae2e2 | ||
|
|
a0bf33eca6 | ||
|
|
88dd9c715d | ||
|
|
a3e21df814 | ||
|
|
d3b94c9241 | ||
|
|
c1d7599829 | ||
|
|
d11936f292 | ||
|
|
17363edf25 | ||
|
|
279cbbbb8a | ||
|
|
486cd4c343 | ||
|
|
25feceb783 | ||
|
|
d26752250d | ||
|
|
b15453c369 | ||
|
|
04ba8c8bc3 | ||
|
|
6570692291 | ||
|
|
f73d55ddaa | ||
|
|
13aa5b3375 | ||
|
|
0fcc02fbea | ||
|
|
c03883ccf0 | ||
|
|
134a9eac9d | ||
|
|
6d8de0ade4 | ||
|
|
1587ff5e74 | ||
|
|
f033d3a6df | ||
|
|
145e0e0b5d | ||
|
|
f8d1bc06ea | ||
|
|
d5930f4e44 | ||
|
|
9b7d7021af | ||
|
|
e41c22ef44 | ||
|
|
5fc2bd393e | ||
|
|
55271403fb | ||
|
|
36fba66619 | ||
|
|
66eb12294a | ||
|
|
73b22ec29b | ||
|
|
c31ae2f3b5 | ||
|
|
76b53d6b5b | ||
|
|
a34dfed378 | ||
|
|
b9b127a7ea | ||
|
|
2741e7b7b3 | ||
|
|
1767a56d4f | ||
|
|
779e6c2d2f | ||
|
|
73c831747b | ||
|
|
b8b89f34f4 | ||
|
|
e5d3541b5a | ||
|
|
79755e76ea | ||
|
|
35f158d526 | ||
|
|
6962e09dd9 | ||
|
|
4c4cbd44da | ||
|
|
26eca8b6ba | ||
|
|
62b17f40a1 | ||
|
|
511b8a992e | ||
|
|
0ab977c236 | ||
|
|
224f0de353 | ||
|
|
d54de441d3 | ||
|
|
754f3bcbc3 | ||
|
|
36973d4a6f | ||
|
|
c89d19b300 | ||
|
|
cc32f5ff61 | ||
|
|
fbff68b9e0 | ||
|
|
7e1a543b79 | ||
|
|
74b862d8b8 | ||
|
|
36efcc6e28 | ||
|
|
a337ecf35c | ||
|
|
5c817a9b42 | ||
|
|
e08f68ed7c | ||
|
|
f09ed25fd3 | ||
|
|
5da0decef6 | ||
|
|
e166e56249 | ||
|
|
5f58248016 | ||
|
|
07d6689d87 | ||
|
|
14cb2b95c6 | ||
|
|
fdeef48498 |
81
.github/workflows/agents-md-guard.yml
vendored
Normal file
81
.github/workflows/agents-md-guard.yml
vendored
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
name: agents-md-guard
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- synchronize
|
||||||
|
- reopened
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-when-agents-md-changed:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Detect AGENTS.md changes and close PR
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const prNumber = context.payload.pull_request.number;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
|
||||||
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
per_page: 100,
|
||||||
|
});
|
||||||
|
|
||||||
|
const touchesAgentsMd = (path) =>
|
||||||
|
typeof path === "string" &&
|
||||||
|
(path === "AGENTS.md" || path.endsWith("/AGENTS.md"));
|
||||||
|
|
||||||
|
const touched = files.filter(
|
||||||
|
(f) => touchesAgentsMd(f.filename) || touchesAgentsMd(f.previous_filename),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (touched.length === 0) {
|
||||||
|
core.info("No AGENTS.md changes detected.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const changedList = touched
|
||||||
|
.map((f) =>
|
||||||
|
f.previous_filename && f.previous_filename !== f.filename
|
||||||
|
? `- ${f.previous_filename} -> ${f.filename}`
|
||||||
|
: `- ${f.filename}`,
|
||||||
|
)
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
const body = [
|
||||||
|
"This repository does not allow modifying `AGENTS.md` in pull requests.",
|
||||||
|
"",
|
||||||
|
"Detected changes:",
|
||||||
|
changedList,
|
||||||
|
"",
|
||||||
|
"Please revert these changes and open a new PR without touching `AGENTS.md`.",
|
||||||
|
].join("\n");
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
issue_number: prNumber,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
await github.rest.pulls.update({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
state: "closed",
|
||||||
|
});
|
||||||
|
|
||||||
|
core.setFailed("PR modifies AGENTS.md");
|
||||||
73
.github/workflows/auto-retarget-main-pr-to-dev.yml
vendored
Normal file
73
.github/workflows/auto-retarget-main-pr-to-dev.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
name: auto-retarget-main-pr-to-dev
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- reopened
|
||||||
|
- edited
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
retarget:
|
||||||
|
if: github.actor != 'github-actions[bot]'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Retarget PR base to dev
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const pr = context.payload.pull_request;
|
||||||
|
const prNumber = pr.number;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
|
||||||
|
const baseRef = pr.base?.ref;
|
||||||
|
const headRef = pr.head?.ref;
|
||||||
|
const desiredBase = "dev";
|
||||||
|
|
||||||
|
if (baseRef !== "main") {
|
||||||
|
core.info(`PR #${prNumber} base is ${baseRef}; nothing to do.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (headRef === desiredBase) {
|
||||||
|
core.info(`PR #${prNumber} is ${desiredBase} -> main; skipping retarget.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
core.info(`Retargeting PR #${prNumber} base from ${baseRef} to ${desiredBase}.`);
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.pulls.update({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
base: desiredBase,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.setFailed(`Failed to retarget PR #${prNumber} to ${desiredBase}: ${error.message}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const body = [
|
||||||
|
`This pull request targeted \`${baseRef}\`.`,
|
||||||
|
"",
|
||||||
|
`The base branch has been automatically changed to \`${desiredBase}\`.`,
|
||||||
|
].join("\n");
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
issue_number: prNumber,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
|
||||||
|
}
|
||||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -37,15 +37,16 @@ GEMINI.md
|
|||||||
|
|
||||||
# Tooling metadata
|
# Tooling metadata
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
.worktrees/
|
||||||
.codex/*
|
.codex/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.gemini/*
|
.gemini/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.agent/*
|
.agent/*
|
||||||
.agents/*
|
.agents/*
|
||||||
.agents/*
|
|
||||||
.opencode/*
|
.opencode/*
|
||||||
.idea/*
|
.idea/*
|
||||||
|
.beads/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
@@ -54,4 +55,10 @@ _bmad-output/*
|
|||||||
# macOS
|
# macOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
._*
|
._*
|
||||||
|
|
||||||
|
# Opencode
|
||||||
|
.beads/
|
||||||
|
.opencode/
|
||||||
|
.cli-proxy-api/
|
||||||
|
.venv/
|
||||||
*.bak
|
*.bak
|
||||||
|
|||||||
58
AGENTS.md
Normal file
58
AGENTS.md
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
Go 1.26+ proxy server providing OpenAI/Gemini/Claude/Codex compatible APIs with OAuth and round-robin load balancing.
|
||||||
|
|
||||||
|
## Repository
|
||||||
|
- GitHub: https://github.com/router-for-me/CLIProxyAPI
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
```bash
|
||||||
|
gofmt -w . # Format (required after Go changes)
|
||||||
|
go build -o cli-proxy-api ./cmd/server # Build
|
||||||
|
go run ./cmd/server # Run dev server
|
||||||
|
go test ./... # Run all tests
|
||||||
|
go test -v -run TestName ./path/to/pkg # Run single test
|
||||||
|
go build -o test-output ./cmd/server && rm test-output # Verify compile (REQUIRED after changes)
|
||||||
|
```
|
||||||
|
- Common flags: `--config <path>`, `--tui`, `--standalone`, `--local-model`, `--no-browser`, `--oauth-callback-port <port>`
|
||||||
|
|
||||||
|
## Config
|
||||||
|
- Default config: `config.yaml` (template: `config.example.yaml`)
|
||||||
|
- `.env` is auto-loaded from the working directory
|
||||||
|
- Auth material defaults under `auths/`
|
||||||
|
- Storage backends: file-based default; optional Postgres/git/object store (`PGSTORE_*`, `GITSTORE_*`, `OBJECTSTORE_*`)
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
- `cmd/server/` — Server entrypoint
|
||||||
|
- `internal/api/` — Gin HTTP API (routes, middleware, modules)
|
||||||
|
- `internal/api/modules/amp/` — Amp integration (Amp-style routes + reverse proxy)
|
||||||
|
- `internal/thinking/` — Main thinking/reasoning pipeline. `ApplyThinking()` (apply.go) parses suffixes (`suffix.go`, suffix overrides body), normalizes config to canonical `ThinkingConfig` (`types.go`), normalizes and validates centrally (`validate.go`/`convert.go`), then applies provider-specific output via `ProviderApplier`. Do not break this "canonical representation → per-provider translation" architecture.
|
||||||
|
- `internal/runtime/executor/` — Per-provider runtime executors (incl. Codex WebSocket)
|
||||||
|
- `internal/translator/` — Provider protocol translators (and shared `common`)
|
||||||
|
- `internal/registry/` — Model registry + remote updater (`StartModelsUpdater`); `--local-model` disables remote updates
|
||||||
|
- `internal/store/` — Storage implementations and secret resolution
|
||||||
|
- `internal/managementasset/` — Config snapshots and management assets
|
||||||
|
- `internal/cache/` — Request signature caching
|
||||||
|
- `internal/watcher/` — Config hot-reload and watchers
|
||||||
|
- `internal/wsrelay/` — WebSocket relay sessions
|
||||||
|
- `internal/usage/` — Usage and token accounting
|
||||||
|
- `internal/tui/` — Bubbletea terminal UI (`--tui`, `--standalone`)
|
||||||
|
- `sdk/cliproxy/` — Embeddable SDK entry (service/builder/watchers/pipeline)
|
||||||
|
- `test/` — Cross-module integration tests
|
||||||
|
|
||||||
|
## Code Conventions
|
||||||
|
- Keep changes small and simple (KISS)
|
||||||
|
- Comments in English only
|
||||||
|
- If editing code that already contains non-English comments, translate them to English (don’t add new non-English comments)
|
||||||
|
- For user-visible strings, keep the existing language used in that file/area
|
||||||
|
- New Markdown docs should be in English unless the file is explicitly language-specific (e.g. `README_CN.md`)
|
||||||
|
- As a rule, do not make standalone changes to `internal/translator/`. You may modify it only as part of broader changes elsewhere.
|
||||||
|
- If a task requires changing only `internal/translator/`, run `gh repo view --json viewerPermission -q .viewerPermission` to confirm you have `WRITE`, `MAINTAIN`, or `ADMIN`. If you do, you may proceed; otherwise, file a GitHub issue including the goal, rationale, and the intended implementation code, then stop further work.
|
||||||
|
- `internal/runtime/executor/` should contain executors and their unit tests only. Place any helper/supporting files under `internal/runtime/executor/helps/`.
|
||||||
|
- Follow `gofmt`; keep imports goimports-style; wrap errors with context where helpful
|
||||||
|
- Do not use `log.Fatal`/`log.Fatalf` (terminates the process); prefer returning errors and logging via logrus
|
||||||
|
- Shadowed variables: use method suffix (`errStart := server.Start()`)
|
||||||
|
- Wrap defer errors: `defer func() { if err := f.Close(); err != nil { log.Errorf(...) } }()`
|
||||||
|
- Use logrus structured logging; avoid leaking secrets/tokens in logs
|
||||||
|
- Avoid panics in HTTP handlers; prefer logged errors and meaningful HTTP status codes
|
||||||
|
- Timeouts are allowed only during credential acquisition; after an upstream connection is established, do not set timeouts for any subsequent network behavior. Intentional exceptions that must remain allowed are the Codex websocket liveness deadlines in `internal/runtime/executor/codex_websockets_executor.go`, the wsrelay session deadlines in `internal/wsrelay/session.go`, the management APICall timeout in `internal/api/handlers/management/api_tools.go`, and the `cmd/fetch_antigravity_models` utility timeouts
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# CLIProxyAPI Plus
|
# CLIProxyAPI Plus
|
||||||
|
|
||||||
[English](README.md) | 中文 | [日本語](README_JA.md)
|
[English](README.md) | 中文
|
||||||
|
|
||||||
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
||||||
|
|
||||||
|
|||||||
187
README_JA.md
187
README_JA.md
@@ -1,187 +0,0 @@
|
|||||||
# CLI Proxy API
|
|
||||||
|
|
||||||
[English](README.md) | [中文](README_CN.md) | 日本語
|
|
||||||
|
|
||||||
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
|
|
||||||
|
|
||||||
OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポートしています。
|
|
||||||
|
|
||||||
ローカルまたはマルチアカウントのCLIアクセスを、OpenAI(Responses含む)/Gemini/Claude互換のクライアントやSDKで利用できます。
|
|
||||||
|
|
||||||
## スポンサー
|
|
||||||
|
|
||||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
|
||||||
|
|
||||||
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
|
|
||||||
|
|
||||||
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
|
|
||||||
|
|
||||||
GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<table>
|
|
||||||
<tbody>
|
|
||||||
<tr>
|
|
||||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
|
||||||
<td>PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
|
||||||
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
|
|
||||||
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
|
|
||||||
</tr>
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
## 概要
|
|
||||||
|
|
||||||
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
|
|
||||||
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
|
|
||||||
- OAuthログインによるClaude Codeサポート
|
|
||||||
- OAuthログインによるQwen Codeサポート
|
|
||||||
- OAuthログインによるiFlowサポート
|
|
||||||
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
|
|
||||||
- ストリーミングおよび非ストリーミングレスポンス
|
|
||||||
- 関数呼び出し/ツールのサポート
|
|
||||||
- マルチモーダル入力サポート(テキストと画像)
|
|
||||||
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
|
||||||
- シンプルなCLI認証フロー(Gemini、OpenAI、Claude、QwenおよびiFlow)
|
|
||||||
- Generative Language APIキーのサポート
|
|
||||||
- AI Studioビルドのマルチアカウント負荷分散
|
|
||||||
- Gemini CLIのマルチアカウント負荷分散
|
|
||||||
- Claude Codeのマルチアカウント負荷分散
|
|
||||||
- Qwen Codeのマルチアカウント負荷分散
|
|
||||||
- iFlowのマルチアカウント負荷分散
|
|
||||||
- OpenAI Codexのマルチアカウント負荷分散
|
|
||||||
- 設定によるOpenAI互換アップストリームプロバイダー(例:OpenRouter)
|
|
||||||
- プロキシ埋め込み用の再利用可能なGo SDK(`docs/sdk-usage.md`を参照)
|
|
||||||
|
|
||||||
## はじめに
|
|
||||||
|
|
||||||
CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/)
|
|
||||||
|
|
||||||
## 管理API
|
|
||||||
|
|
||||||
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
|
|
||||||
|
|
||||||
## Amp CLIサポート
|
|
||||||
|
|
||||||
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
|
|
||||||
|
|
||||||
- Ampの APIパターン用のプロバイダールートエイリアス(`/api/provider/{provider}/v1...`)
|
|
||||||
- OAuth認証およびアカウント機能用の管理プロキシ
|
|
||||||
- 自動ルーティングによるスマートモデルフォールバック
|
|
||||||
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
|
|
||||||
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
|
|
||||||
|
|
||||||
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
|
|
||||||
|
|
||||||
## SDKドキュメント
|
|
||||||
|
|
||||||
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
|
|
||||||
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
|
|
||||||
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
|
|
||||||
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
|
|
||||||
- カスタムプロバイダーの例:`examples/custom-provider`
|
|
||||||
|
|
||||||
## コントリビューション
|
|
||||||
|
|
||||||
コントリビューションを歓迎します!お気軽にPull Requestを送ってください。
|
|
||||||
|
|
||||||
1. リポジトリをフォーク
|
|
||||||
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`)
|
|
||||||
3. 変更をコミット(`git commit -m 'Add some amazing feature'`)
|
|
||||||
4. ブランチにプッシュ(`git push origin feature/amazing-feature`)
|
|
||||||
5. Pull Requestを作成
|
|
||||||
|
|
||||||
## 関連プロジェクト
|
|
||||||
|
|
||||||
CLIProxyAPIをベースにした以下のプロジェクトがあります:
|
|
||||||
|
|
||||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
|
||||||
|
|
||||||
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
|
|
||||||
|
|
||||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
|
||||||
|
|
||||||
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
|
|
||||||
|
|
||||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
|
||||||
|
|
||||||
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデル(Gemini、Codex、Antigravity)を即座に切り替えるCLIラッパー - APIキー不要
|
|
||||||
|
|
||||||
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
|
||||||
|
|
||||||
CLIProxyAPI管理用のmacOSネイティブGUI:OAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
|
|
||||||
|
|
||||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
|
||||||
|
|
||||||
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
|
|
||||||
|
|
||||||
### [CodMate](https://github.com/loocor/CodMate)
|
|
||||||
|
|
||||||
CLI AIセッション(Codex、Claude Code、Gemini CLI)を管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
|
|
||||||
|
|
||||||
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
|
||||||
|
|
||||||
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
|
|
||||||
|
|
||||||
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
|
||||||
|
|
||||||
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
|
|
||||||
|
|
||||||
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
|
||||||
|
|
||||||
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
|
|
||||||
|
|
||||||
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
|
||||||
|
|
||||||
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
|
|
||||||
|
|
||||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
|
||||||
|
|
||||||
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替え(Main / Plus)、自動ダウンロードおよび自動更新に対応
|
|
||||||
|
|
||||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
|
||||||
|
|
||||||
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
|
|
||||||
|
|
||||||
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
|
||||||
|
|
||||||
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
|
|
||||||
|
|
||||||
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
|
||||||
|
|
||||||
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
|
|
||||||
|
|
||||||
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
|
|
||||||
|
|
||||||
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LAN(ローカルエリアネットワーク)を介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
|
||||||
|
|
||||||
## その他の選択肢
|
|
||||||
|
|
||||||
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです:
|
|
||||||
|
|
||||||
### [9Router](https://github.com/decolua/9router)
|
|
||||||
|
|
||||||
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換(OpenAI/Claude/Gemini/Ollama)、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツール(Cursor、Claude Code、Cline、RooCode)のサポートをゼロから構築 - APIキー不要
|
|
||||||
|
|
||||||
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
|
||||||
|
|
||||||
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
|
|
||||||
|
|
||||||
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
|
|
||||||
|
|
||||||
## ライセンス
|
|
||||||
|
|
||||||
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。
|
|
||||||
BIN
assets/lingtrue.png
Normal file
BIN
assets/lingtrue.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 129 KiB |
@@ -26,6 +26,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
@@ -188,7 +189,7 @@ func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
|||||||
httpReq.Close = true
|
httpReq.Close = true
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
|
httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent())
|
||||||
|
|
||||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ func main() {
|
|||||||
var codeBuddyLogin bool
|
var codeBuddyLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
|
var vertexImportPrefix string
|
||||||
var configPath string
|
var configPath string
|
||||||
var password string
|
var password string
|
||||||
var tuiMode bool
|
var tuiMode bool
|
||||||
@@ -139,6 +140,7 @@ func main() {
|
|||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
|
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
||||||
flag.StringVar(&password, "password", "", "")
|
flag.StringVar(&password, "password", "", "")
|
||||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||||
@@ -188,6 +190,7 @@ func main() {
|
|||||||
gitStoreRemoteURL string
|
gitStoreRemoteURL string
|
||||||
gitStoreUser string
|
gitStoreUser string
|
||||||
gitStorePassword string
|
gitStorePassword string
|
||||||
|
gitStoreBranch string
|
||||||
gitStoreLocalPath string
|
gitStoreLocalPath string
|
||||||
gitStoreInst *store.GitTokenStore
|
gitStoreInst *store.GitTokenStore
|
||||||
gitStoreRoot string
|
gitStoreRoot string
|
||||||
@@ -257,6 +260,9 @@ func main() {
|
|||||||
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
||||||
gitStoreLocalPath = value
|
gitStoreLocalPath = value
|
||||||
}
|
}
|
||||||
|
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
|
||||||
|
gitStoreBranch = value
|
||||||
|
}
|
||||||
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
||||||
useObjectStore = true
|
useObjectStore = true
|
||||||
objectStoreEndpoint = value
|
objectStoreEndpoint = value
|
||||||
@@ -391,7 +397,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
||||||
authDir := filepath.Join(gitStoreRoot, "auths")
|
authDir := filepath.Join(gitStoreRoot, "auths")
|
||||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
|
||||||
gitStoreInst.SetBaseDir(authDir)
|
gitStoreInst.SetBaseDir(authDir)
|
||||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||||
log.Errorf("failed to prepare git token store: %v", errRepo)
|
log.Errorf("failed to prepare git token store: %v", errRepo)
|
||||||
@@ -510,7 +516,7 @@ func main() {
|
|||||||
|
|
||||||
if vertexImport != "" {
|
if vertexImport != "" {
|
||||||
// Handle Vertex service account import
|
// Handle Vertex service account import
|
||||||
cmd.DoVertexImport(cfg, vertexImport)
|
cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix)
|
||||||
} else if login {
|
} else if login {
|
||||||
// Handle Google/Gemini login
|
// Handle Google/Gemini login
|
||||||
cmd.DoLogin(cfg, projectID, options)
|
cmd.DoLogin(cfg, projectID, options)
|
||||||
@@ -596,6 +602,7 @@ func main() {
|
|||||||
if standalone {
|
if standalone {
|
||||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
misc.StartAntigravityVersionUpdater(context.Background())
|
||||||
if !localModel {
|
if !localModel {
|
||||||
registry.StartModelsUpdater(context.Background())
|
registry.StartModelsUpdater(context.Background())
|
||||||
}
|
}
|
||||||
@@ -671,6 +678,7 @@ func main() {
|
|||||||
} else {
|
} else {
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
misc.StartAntigravityVersionUpdater(context.Background())
|
||||||
if !localModel {
|
if !localModel {
|
||||||
registry.StartModelsUpdater(context.Background())
|
registry.StartModelsUpdater(context.Background())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,10 +92,14 @@ max-retry-credentials: 0
|
|||||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||||
max-retry-interval: 30
|
max-retry-interval: 30
|
||||||
|
|
||||||
|
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||||
|
disable-cooling: false
|
||||||
|
|
||||||
# Quota exceeded behavior
|
# Quota exceeded behavior
|
||||||
quota-exceeded:
|
quota-exceeded:
|
||||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||||
|
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
|
||||||
|
|
||||||
# Routing strategy for selecting credentials when multiple match.
|
# Routing strategy for selecting credentials when multiple match.
|
||||||
routing:
|
routing:
|
||||||
@@ -104,6 +108,10 @@ routing:
|
|||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# When true, enable Gemini CLI internal endpoints (/v1internal:*).
|
||||||
|
# Default is false for safety.
|
||||||
|
enable-gemini-cli-endpoint: false
|
||||||
|
|
||||||
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||||
nonstream-keepalive-interval: 0
|
nonstream-keepalive-interval: 0
|
||||||
|
|
||||||
@@ -177,6 +185,8 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "API"
|
# - "API"
|
||||||
# - "proxy"
|
# - "proxy"
|
||||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||||
|
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
|
||||||
|
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
|
||||||
|
|
||||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||||
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||||
@@ -313,6 +323,10 @@ nonstream-keepalive-interval: 0
|
|||||||
# These aliases rename model IDs for both model listing and request routing.
|
# These aliases rename model IDs for both model listing and request routing.
|
||||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
|
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||||
|
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||||
|
# you select the protocol surface, but inference backend selection can still follow the resolved
|
||||||
|
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
|
||||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||||
# oauth-model-alias:
|
# oauth-model-alias:
|
||||||
# antigravity:
|
# antigravity:
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -83,6 +83,7 @@ require (
|
|||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/muesli/termenv v0.16.0 // indirect
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
|
github.com/pierrec/xxHash v0.1.5
|
||||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/rs/xid v1.5.0 // indirect
|
github.com/rs/xid v1.5.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
|||||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
|
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
|
||||||
|
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
|
||||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
@@ -700,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||||
}
|
}
|
||||||
|
if h != nil && h.cfg != nil {
|
||||||
|
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
|
||||||
|
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if h != nil && h.cfg != nil {
|
if h != nil && h.cfg != nil {
|
||||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||||
@@ -722,6 +728,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type apiKeyConfigEntry interface {
|
||||||
|
GetAPIKey() string
|
||||||
|
GetBaseURL() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
|
||||||
|
if auth == nil || len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
attrKey, attrBase := "", ""
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||||
|
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||||
|
}
|
||||||
|
for i := range entries {
|
||||||
|
entry := &entries[i]
|
||||||
|
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
|
||||||
|
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
|
||||||
|
if attrKey != "" && attrBase != "" {
|
||||||
|
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||||
|
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if attrKey != "" {
|
||||||
|
for i := range entries {
|
||||||
|
entry := &entries[i]
|
||||||
|
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
authKind, authAccount := auth.AccountInfo()
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := auth.Attributes
|
||||||
|
compatName := ""
|
||||||
|
providerKey := ""
|
||||||
|
if len(attrs) > 0 {
|
||||||
|
compatName = strings.TrimSpace(attrs["compat_name"])
|
||||||
|
providerKey = strings.TrimSpace(attrs["provider_key"])
|
||||||
|
}
|
||||||
|
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||||
|
return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
|
||||||
|
case "gemini":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
case "claude":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
case "codex":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
apiKey = strings.TrimSpace(apiKey)
|
||||||
|
if apiKey == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
candidates := make([]string, 0, 3)
|
||||||
|
if v := strings.TrimSpace(compatName); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(providerKey); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(auth.Provider); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cfg.OpenAICompatibility {
|
||||||
|
compat := &cfg.OpenAICompatibility[i]
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
||||||
|
for j := range compat.APIKeyEntries {
|
||||||
|
entry := &compat.APIKeyEntries[j]
|
||||||
|
if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||||
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||||
if errBuild != nil {
|
if errBuild != nil {
|
||||||
|
|||||||
@@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
GeminiKey: []config.GeminiKey{{
|
||||||
|
APIKey: "gemini-key",
|
||||||
|
ProxyURL: "http://gemini-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
ClaudeKey: []config.ClaudeKey{{
|
||||||
|
APIKey: "claude-key",
|
||||||
|
ProxyURL: "http://claude-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
CodexKey: []config.CodexKey{{
|
||||||
|
APIKey: "codex-key",
|
||||||
|
ProxyURL: "http://codex-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
OpenAICompatibility: []config.OpenAICompatibility{{
|
||||||
|
Name: "bohe",
|
||||||
|
BaseURL: "https://bohe.example.com",
|
||||||
|
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
|
||||||
|
APIKey: "compat-key",
|
||||||
|
ProxyURL: "http://compat-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth *coreauth.Auth
|
||||||
|
wantProxy string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "gemini",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "gemini",
|
||||||
|
Attributes: map[string]string{"api_key": "gemini-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://gemini-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claude",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{"api_key": "claude-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://claude-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "codex",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Attributes: map[string]string{"api_key": "codex-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://codex-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai-compatibility",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "bohe",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "compat-key",
|
||||||
|
"compat_name": "bohe",
|
||||||
|
"provider_key": "bohe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantProxy: "http://compat-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
transport := h.apiCallTransport(tc.auth)
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != tc.wantProxy {
|
||||||
|
t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -1047,6 +1047,7 @@ func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Aut
|
|||||||
auth.Runtime = existing.Runtime
|
auth.Runtime = existing.Runtime
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
coreauth.ApplyCustomHeadersFromMetadata(auth)
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1129,7 +1130,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
|
// PatchAuthFileFields updates editable fields (prefix, proxy_url, headers, priority, note) of an auth file.
|
||||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||||
if h.authManager == nil {
|
if h.authManager == nil {
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||||
@@ -1137,11 +1138,12 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var req struct {
|
var req struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Prefix *string `json:"prefix"`
|
Prefix *string `json:"prefix"`
|
||||||
ProxyURL *string `json:"proxy_url"`
|
ProxyURL *string `json:"proxy_url"`
|
||||||
Priority *int `json:"priority"`
|
Headers map[string]string `json:"headers"`
|
||||||
Note *string `json:"note"`
|
Priority *int `json:"priority"`
|
||||||
|
Note *string `json:"note"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||||
@@ -1177,13 +1179,107 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
|||||||
|
|
||||||
changed := false
|
changed := false
|
||||||
if req.Prefix != nil {
|
if req.Prefix != nil {
|
||||||
targetAuth.Prefix = *req.Prefix
|
prefix := strings.TrimSpace(*req.Prefix)
|
||||||
|
targetAuth.Prefix = prefix
|
||||||
|
if targetAuth.Metadata == nil {
|
||||||
|
targetAuth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
if prefix == "" {
|
||||||
|
delete(targetAuth.Metadata, "prefix")
|
||||||
|
} else {
|
||||||
|
targetAuth.Metadata["prefix"] = prefix
|
||||||
|
}
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
if req.ProxyURL != nil {
|
if req.ProxyURL != nil {
|
||||||
targetAuth.ProxyURL = *req.ProxyURL
|
proxyURL := strings.TrimSpace(*req.ProxyURL)
|
||||||
|
targetAuth.ProxyURL = proxyURL
|
||||||
|
if targetAuth.Metadata == nil {
|
||||||
|
targetAuth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
if proxyURL == "" {
|
||||||
|
delete(targetAuth.Metadata, "proxy_url")
|
||||||
|
} else {
|
||||||
|
targetAuth.Metadata["proxy_url"] = proxyURL
|
||||||
|
}
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
|
if len(req.Headers) > 0 {
|
||||||
|
existingHeaders := coreauth.ExtractCustomHeadersFromMetadata(targetAuth.Metadata)
|
||||||
|
nextHeaders := make(map[string]string, len(existingHeaders))
|
||||||
|
for k, v := range existingHeaders {
|
||||||
|
nextHeaders[k] = v
|
||||||
|
}
|
||||||
|
headerChanged := false
|
||||||
|
|
||||||
|
for key, value := range req.Headers {
|
||||||
|
name := strings.TrimSpace(key)
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val := strings.TrimSpace(value)
|
||||||
|
attrKey := "header:" + name
|
||||||
|
if val == "" {
|
||||||
|
if _, ok := nextHeaders[name]; ok {
|
||||||
|
delete(nextHeaders, name)
|
||||||
|
headerChanged = true
|
||||||
|
}
|
||||||
|
if targetAuth.Attributes != nil {
|
||||||
|
if _, ok := targetAuth.Attributes[attrKey]; ok {
|
||||||
|
headerChanged = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if prev, ok := nextHeaders[name]; !ok || prev != val {
|
||||||
|
headerChanged = true
|
||||||
|
}
|
||||||
|
nextHeaders[name] = val
|
||||||
|
if targetAuth.Attributes != nil {
|
||||||
|
if prev, ok := targetAuth.Attributes[attrKey]; !ok || prev != val {
|
||||||
|
headerChanged = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
headerChanged = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if headerChanged {
|
||||||
|
if targetAuth.Metadata == nil {
|
||||||
|
targetAuth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
if targetAuth.Attributes == nil {
|
||||||
|
targetAuth.Attributes = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range req.Headers {
|
||||||
|
name := strings.TrimSpace(key)
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val := strings.TrimSpace(value)
|
||||||
|
attrKey := "header:" + name
|
||||||
|
if val == "" {
|
||||||
|
delete(nextHeaders, name)
|
||||||
|
delete(targetAuth.Attributes, attrKey)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nextHeaders[name] = val
|
||||||
|
targetAuth.Attributes[attrKey] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nextHeaders) == 0 {
|
||||||
|
delete(targetAuth.Metadata, "headers")
|
||||||
|
} else {
|
||||||
|
metaHeaders := make(map[string]any, len(nextHeaders))
|
||||||
|
for k, v := range nextHeaders {
|
||||||
|
metaHeaders[k] = v
|
||||||
|
}
|
||||||
|
targetAuth.Metadata["headers"] = metaHeaders
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
if req.Priority != nil || req.Note != nil {
|
if req.Priority != nil || req.Note != nil {
|
||||||
if targetAuth.Metadata == nil {
|
if targetAuth.Metadata == nil {
|
||||||
targetAuth.Metadata = make(map[string]any)
|
targetAuth.Metadata = make(map[string]any)
|
||||||
@@ -2138,9 +2234,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
|
|||||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||||
metadata["auth_kind"] = "oauth"
|
metadata["auth_kind"] = "oauth"
|
||||||
metadata["oauth_client_id"] = clientID
|
metadata["oauth_client_id"] = clientID
|
||||||
if clientSecret != "" {
|
|
||||||
metadata["oauth_client_secret"] = clientSecret
|
|
||||||
}
|
|
||||||
metadata["username"] = strings.TrimSpace(user.Username)
|
metadata["username"] = strings.TrimSpace(user.Username)
|
||||||
if email := primaryGitLabEmail(user); email != "" {
|
if email := primaryGitLabEmail(user); email != "" {
|
||||||
metadata["email"] = email
|
metadata["email"] = email
|
||||||
|
|||||||
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "test.json",
|
||||||
|
FileName: "test.json",
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": "/tmp/test.json",
|
||||||
|
"header:X-Old": "old",
|
||||||
|
"header:X-Remove": "gone",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "claude",
|
||||||
|
"headers": map[string]any{
|
||||||
|
"X-Old": "old",
|
||||||
|
"X-Remove": "gone",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||||
|
|
||||||
|
body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
h.PatchAuthFileFields(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID("test.json")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth record to exist after patch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated.Prefix != "p1" {
|
||||||
|
t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
|
||||||
|
}
|
||||||
|
if updated.ProxyURL != "http://proxy.local" {
|
||||||
|
t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated.Metadata == nil {
|
||||||
|
t.Fatalf("expected metadata to be non-nil")
|
||||||
|
}
|
||||||
|
if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
|
||||||
|
t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
|
||||||
|
}
|
||||||
|
if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
|
||||||
|
t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
|
||||||
|
}
|
||||||
|
|
||||||
|
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
raw, _ := json.Marshal(updated.Metadata["headers"])
|
||||||
|
t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-Old"]; got != "new" {
|
||||||
|
t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-New"]; got != "v" {
|
||||||
|
t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
|
||||||
|
}
|
||||||
|
if _, ok := headersMeta["X-Remove"]; ok {
|
||||||
|
t.Fatalf("expected metadata.headers.X-Remove to be deleted")
|
||||||
|
}
|
||||||
|
if _, ok := headersMeta["X-Nope"]; ok {
|
||||||
|
t.Fatalf("expected metadata.headers.X-Nope to be absent")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := updated.Attributes["header:X-Old"]; got != "new" {
|
||||||
|
t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
|
||||||
|
}
|
||||||
|
if got := updated.Attributes["header:X-New"]; got != "v" {
|
||||||
|
t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
|
||||||
|
}
|
||||||
|
if _, ok := updated.Attributes["header:X-Remove"]; ok {
|
||||||
|
t.Fatalf("expected attrs header:X-Remove to be deleted")
|
||||||
|
}
|
||||||
|
if _, ok := updated.Attributes["header:X-Nope"]; ok {
|
||||||
|
t.Fatalf("expected attrs header:X-Nope to be absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "noop.json",
|
||||||
|
FileName: "noop.json",
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": "/tmp/noop.json",
|
||||||
|
"header:X-Kee": "1",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "claude",
|
||||||
|
"headers": map[string]any{
|
||||||
|
"X-Kee": "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||||
|
|
||||||
|
body := `{"name":"noop.json","note":"hello","headers":{}}`
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
h.PatchAuthFileFields(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID("noop.json")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth record to exist after patch")
|
||||||
|
}
|
||||||
|
if got := updated.Attributes["header:X-Kee"]; got != "1" {
|
||||||
|
t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
|
||||||
|
}
|
||||||
|
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-Kee"]; got != "1" {
|
||||||
|
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -214,19 +214,46 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.GeminiKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||||
|
for _, v := range h.cfg.GeminiKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
if len(out) != len(h.cfg.GeminiKey) {
|
||||||
|
h.cfg.GeminiKey = out
|
||||||
|
h.cfg.SanitizeGeminiKeys()
|
||||||
|
h.persist(c)
|
||||||
|
} else {
|
||||||
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if len(out) != len(h.cfg.GeminiKey) {
|
|
||||||
h.cfg.GeminiKey = out
|
matchIndex := -1
|
||||||
h.cfg.SanitizeGeminiKeys()
|
matchCount := 0
|
||||||
h.persist(c)
|
for i := range h.cfg.GeminiKey {
|
||||||
} else {
|
if strings.TrimSpace(h.cfg.GeminiKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount == 0 {
|
||||||
c.JSON(404, gin.H{"error": "item not found"})
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:matchIndex], h.cfg.GeminiKey[matchIndex+1:]...)
|
||||||
|
h.cfg.SanitizeGeminiKeys()
|
||||||
|
h.persist(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if idxStr := c.Query("index"); idxStr != "" {
|
if idxStr := c.Query("index"); idxStr != "" {
|
||||||
@@ -335,14 +362,39 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||||
if val := c.Query("api-key"); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.ClaudeKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||||
|
for _, v := range h.cfg.ClaudeKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.ClaudeKey = out
|
||||||
|
h.cfg.SanitizeClaudeKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.ClaudeKey {
|
||||||
|
if strings.TrimSpace(h.cfg.ClaudeKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:matchIndex], h.cfg.ClaudeKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.ClaudeKey = out
|
|
||||||
h.cfg.SanitizeClaudeKeys()
|
h.cfg.SanitizeClaudeKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -601,13 +653,38 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
||||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.VertexCompatAPIKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
||||||
|
for _, v := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.VertexCompatAPIKey = out
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if strings.TrimSpace(h.cfg.VertexCompatAPIKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:matchIndex], h.cfg.VertexCompatAPIKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.VertexCompatAPIKey = out
|
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -919,14 +996,39 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||||
if val := c.Query("api-key"); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.CodexKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||||
|
for _, v := range h.cfg.CodexKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.CodexKey = out
|
||||||
|
h.cfg.SanitizeCodexKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.CodexKey {
|
||||||
|
if strings.TrimSpace(h.cfg.CodexKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.CodexKey = append(h.cfg.CodexKey[:matchIndex], h.cfg.CodexKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.CodexKey = out
|
|
||||||
h.cfg.SanitizeCodexKeys()
|
h.cfg.SanitizeCodexKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeTestConfigFile(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if errWrite := os.WriteFile(path, []byte("{}\n"), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write test config: %v", errWrite)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGeminiKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
GeminiKey: []config.GeminiKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key", nil)
|
||||||
|
|
||||||
|
h.DeleteGeminiKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.GeminiKey); got != 2 {
|
||||||
|
t.Fatalf("gemini keys len = %d, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGeminiKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
GeminiKey: []config.GeminiKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key&base-url=https://a.example.com", nil)
|
||||||
|
|
||||||
|
h.DeleteGeminiKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.GeminiKey); got != 1 {
|
||||||
|
t.Fatalf("gemini keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.GeminiKey[0].BaseURL; got != "https://b.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://b.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteClaudeKey_DeletesEmptyBaseURLWhenExplicitlyProvided(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: ""},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://claude.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/claude-api-key?api-key=shared-key&base-url=", nil)
|
||||||
|
|
||||||
|
h.DeleteClaudeKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.ClaudeKey); got != 1 {
|
||||||
|
t.Fatalf("claude keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.ClaudeKey[0].BaseURL; got != "https://claude.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://claude.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteVertexCompatKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/vertex-api-key?api-key=shared-key&base-url=https://b.example.com", nil)
|
||||||
|
|
||||||
|
h.DeleteVertexCompatKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.VertexCompatAPIKey); got != 1 {
|
||||||
|
t.Fatalf("vertex keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.VertexCompatAPIKey[0].BaseURL; got != "https://a.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://a.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCodexKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
CodexKey: []config.CodexKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/codex-api-key?api-key=shared-key", nil)
|
||||||
|
|
||||||
|
h.DeleteCodexKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.CodexKey); got != 2 {
|
||||||
|
t.Fatalf("codex keys len = %d, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
|
||||||
|
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
|
||||||
|
|
||||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||||
type RequestInfo struct {
|
type RequestInfo struct {
|
||||||
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
if len(apiResponse) > 0 {
|
if len(apiResponse) > 0 {
|
||||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||||
}
|
}
|
||||||
|
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
|
||||||
|
if len(apiWebsocketTimeline) > 0 {
|
||||||
|
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
|
||||||
|
}
|
||||||
if err := w.streamWriter.Close(); err != nil {
|
if err := w.streamWriter.Close(); err != nil {
|
||||||
w.streamWriter = nil
|
w.streamWriter = nil
|
||||||
return err
|
return err
|
||||||
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||||
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
|
||||||
|
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
|
||||||
|
if !isExist {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, ok := apiTimeline.([]byte)
|
||||||
|
if !ok || len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(data)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||||
if !isExist {
|
if !isExist {
|
||||||
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||||
if c != nil {
|
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
|
||||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
return body
|
||||||
switch value := bodyOverride.(type) {
|
|
||||||
case []byte:
|
|
||||||
if len(value) > 0 {
|
|
||||||
return bytes.Clone(value)
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
if strings.TrimSpace(value) != "" {
|
|
||||||
return []byte(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
return w.requestInfo.Body
|
return w.requestInfo.Body
|
||||||
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
|
||||||
|
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if w.body == nil || w.body.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(w.body.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
|
||||||
|
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBodyOverride(c *gin.Context, key string) []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bodyOverride, isExist := c.Get(key)
|
||||||
|
if !isExist {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch value := bodyOverride.(type) {
|
||||||
|
case []byte:
|
||||||
|
if len(value) > 0 {
|
||||||
|
return bytes.Clone(value)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return []byte(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||||
if w.requestInfo == nil {
|
if w.requestInfo == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||||
}); ok {
|
}); ok {
|
||||||
return loggerWithOptions.LogRequestWithOptions(
|
return loggerWithOptions.LogRequestWithOptions(
|
||||||
w.requestInfo.URL,
|
w.requestInfo.URL,
|
||||||
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
|||||||
statusCode,
|
statusCode,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
|
websocketTimeline,
|
||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
forceLog,
|
forceLog,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
|||||||
statusCode,
|
statusCode,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
|
websocketTimeline,
|
||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
w.requestInfo.Timestamp,
|
w.requestInfo.Timestamp,
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||||
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
|||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
wrapper := &ResponseWriterWrapper{}
|
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||||
|
|
||||||
body := wrapper.extractRequestBody(c)
|
body := wrapper.extractRequestBody(c)
|
||||||
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
|||||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||||
|
wrapper.body.WriteString("original-response")
|
||||||
|
|
||||||
|
body := wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "original-response" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "original-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
|
||||||
|
body = wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "override-response" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "override-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'X'
|
||||||
|
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
|
||||||
|
t.Fatalf("response override should be cloned, got %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
|
||||||
|
|
||||||
|
body := wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "override-response-as-string" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
override := []byte("body-override")
|
||||||
|
c.Set(requestBodyOverrideContextKey, override)
|
||||||
|
|
||||||
|
body := extractBodyOverride(c, requestBodyOverrideContextKey)
|
||||||
|
if !bytes.Equal(body, override) {
|
||||||
|
t.Fatalf("body override = %q, want %q", string(body), string(override))
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'X'
|
||||||
|
if !bytes.Equal(override, []byte("body-override")) {
|
||||||
|
t.Fatalf("override mutated: %q", string(override))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
if got := wrapper.extractWebsocketTimeline(c); got != nil {
|
||||||
|
t.Fatalf("expected nil websocket timeline, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
|
||||||
|
body := wrapper.extractWebsocketTimeline(c)
|
||||||
|
if string(body) != "timeline" {
|
||||||
|
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
streamWriter := &testStreamingLogWriter{}
|
||||||
|
wrapper := &ResponseWriterWrapper{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
logger: &testRequestLogger{enabled: true},
|
||||||
|
requestInfo: &RequestInfo{
|
||||||
|
URL: "/v1/responses",
|
||||||
|
Method: "POST",
|
||||||
|
Headers: map[string][]string{"Content-Type": {"application/json"}},
|
||||||
|
RequestID: "req-1",
|
||||||
|
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
isStreaming: true,
|
||||||
|
streamWriter: streamWriter,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
|
||||||
|
|
||||||
|
if err := wrapper.Finalize(c); err != nil {
|
||||||
|
t.Fatalf("Finalize error: %v", err)
|
||||||
|
}
|
||||||
|
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
|
||||||
|
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
|
||||||
|
}
|
||||||
|
if !streamWriter.closed {
|
||||||
|
t.Fatal("expected stream writer to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRequestLogger struct {
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
|
||||||
|
return &testStreamingLogWriter{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) IsEnabled() bool {
|
||||||
|
return l.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
type testStreamingLogWriter struct {
|
||||||
|
apiWebsocketTimeline []byte
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||||
|
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) Close() error {
|
||||||
|
w.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize request body: remove thinking blocks with invalid signatures
|
||||||
|
// to prevent upstream API 400 errors
|
||||||
|
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
|
||||||
|
|
||||||
// Restore the body for the handler to read
|
// Restore the body for the handler to read
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
@@ -249,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = true
|
||||||
c.Writer = rewriter
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
@@ -259,10 +264,17 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
} else if len(providers) > 0 {
|
} else if len(providers) > 0 {
|
||||||
// Log: Using local provider (free)
|
// Log: Using local provider (free)
|
||||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
// Wrap with ResponseRewriter for local providers too, because upstream
|
||||||
|
// proxies (e.g. NewAPI) may return a different model name and lack
|
||||||
|
// Amp-required fields like thinking.signature.
|
||||||
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = providerName != "claude"
|
||||||
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
handler(c)
|
handler(c)
|
||||||
|
rewriter.Flush()
|
||||||
} else {
|
} else {
|
||||||
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|||||||
@@ -129,11 +129,11 @@ func TestModifyResponse_GzipScenarios(t *testing.T) {
|
|||||||
wantCE: "",
|
wantCE: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "skips_non_2xx_status",
|
name: "decompresses_non_2xx_status_when_gzip_detected",
|
||||||
header: http.Header{},
|
header: http.Header{},
|
||||||
body: good,
|
body: good,
|
||||||
status: 404,
|
status: 404,
|
||||||
wantBody: good,
|
wantBody: goodJSON,
|
||||||
wantCE: "",
|
wantCE: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -12,15 +14,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||||
// It's used to rewrite model names in responses when model mapping is used
|
// It is used to rewrite model names in responses when model mapping is used
|
||||||
|
// and to keep Amp-compatible response shapes.
|
||||||
type ResponseRewriter struct {
|
type ResponseRewriter struct {
|
||||||
gin.ResponseWriter
|
gin.ResponseWriter
|
||||||
body *bytes.Buffer
|
body *bytes.Buffer
|
||||||
originalModel string
|
originalModel string
|
||||||
isStreaming bool
|
isStreaming bool
|
||||||
|
suppressThinking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponseRewriter creates a new response rewriter for model name substitution
|
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
||||||
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||||
return &ResponseRewriter{
|
return &ResponseRewriter{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
@@ -33,15 +37,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
|||||||
|
|
||||||
func looksLikeSSEChunk(data []byte) bool {
|
func looksLikeSSEChunk(data []byte) bool {
|
||||||
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||||
// Heuristics are intentionally simple and cheap.
|
// We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
|
||||||
return bytes.Contains(data, []byte("data:")) ||
|
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||||
bytes.Contains(data, []byte("event:")) ||
|
trimmed := bytes.TrimSpace(line)
|
||||||
bytes.Contains(data, []byte("message_start")) ||
|
if bytes.HasPrefix(trimmed, []byte("data:")) ||
|
||||||
bytes.Contains(data, []byte("message_delta")) ||
|
bytes.HasPrefix(trimmed, []byte("event:")) {
|
||||||
bytes.Contains(data, []byte("content_block_start")) ||
|
return true
|
||||||
bytes.Contains(data, []byte("content_block_delta")) ||
|
}
|
||||||
bytes.Contains(data, []byte("content_block_stop")) ||
|
}
|
||||||
bytes.Contains(data, []byte("\n\n"))
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||||
@@ -95,7 +99,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
rewritten := rw.rewriteStreamChunk(data)
|
||||||
|
n, err := rw.ResponseWriter.Write(rewritten)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -106,7 +111,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
return rw.body.Write(data)
|
return rw.body.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush writes the buffered response with model names rewritten
|
|
||||||
func (rw *ResponseRewriter) Flush() {
|
func (rw *ResponseRewriter) Flush() {
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
@@ -115,40 +119,79 @@ func (rw *ResponseRewriter) Flush() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if rw.body.Len() > 0 {
|
if rw.body.Len() > 0 {
|
||||||
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||||
|
// Update Content-Length to match the rewritten body size, since
|
||||||
|
// signature injection and model name changes alter the payload length.
|
||||||
|
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
|
||||||
|
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
|
||||||
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// modelFieldPaths lists all JSON paths where model name may appear
|
|
||||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||||
|
|
||||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
func ensureAmpSignature(data []byte) []byte {
|
||||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
blockType := block.Get("type").String()
|
||||||
|
if blockType != "tool_use" && blockType != "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
signaturePath := fmt.Sprintf("content.%d.signature", index)
|
||||||
|
if gjson.GetBytes(data, signaturePath).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, signaturePath, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
|
||||||
|
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, "content_block.signature", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||||
|
if !rw.suppressThinking {
|
||||||
|
return data
|
||||||
|
}
|
||||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||||
if filtered.Exists() {
|
if filtered.Exists() {
|
||||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||||
filteredCount := filtered.Get("#").Int()
|
filteredCount := filtered.Get("#").Int()
|
||||||
|
|
||||||
if originalCount > filteredCount {
|
if originalCount > filteredCount {
|
||||||
var err error
|
var err error
|
||||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||||
} else {
|
|
||||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
|
||||||
// Log the result for verification
|
|
||||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
data = ensureAmpSignature(data)
|
||||||
|
data = rw.suppressAmpThinking(data)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
if rw.originalModel == "" {
|
if rw.originalModel == "" {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
@@ -160,24 +203,164 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
|
||||||
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
if rw.originalModel == "" {
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
return chunk
|
var out [][]byte
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(lines) {
|
||||||
|
line := lines[i]
|
||||||
|
trimmed := bytes.TrimSpace(line)
|
||||||
|
|
||||||
|
// Case 1: "event:" line - look ahead for its "data:" line
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("event: ")) {
|
||||||
|
// Scan forward past blank lines to find the data: line
|
||||||
|
dataIdx := -1
|
||||||
|
for j := i + 1; j < len(lines); j++ {
|
||||||
|
t := bytes.TrimSpace(lines[j])
|
||||||
|
if len(t) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(t, []byte("data: ")) {
|
||||||
|
dataIdx = j
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataIdx >= 0 {
|
||||||
|
// Found event+data pair - process through rewriter
|
||||||
|
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
|
if rewritten == nil {
|
||||||
|
i = dataIdx + 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Emit event line
|
||||||
|
out = append(out, line)
|
||||||
|
// Emit blank lines between event and data
|
||||||
|
for k := i + 1; k < dataIdx; k++ {
|
||||||
|
out = append(out, lines[k])
|
||||||
|
}
|
||||||
|
// Emit rewritten data
|
||||||
|
out = append(out, append([]byte("data: "), rewritten...))
|
||||||
|
i = dataIdx + 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No data line found (orphan event from cross-chunk split)
|
||||||
|
// Pass it through as-is - the data will arrive in the next chunk
|
||||||
|
out = append(out, line)
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: standalone "data:" line (no preceding event: in this chunk)
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
|
if rewritten != nil {
|
||||||
|
out = append(out, append([]byte("data: "), rewritten...))
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: everything else
|
||||||
|
out = append(out, line)
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSE format: "data: {json}\n\n"
|
return bytes.Join(out, []byte("\n"))
|
||||||
lines := bytes.Split(chunk, []byte("\n"))
|
}
|
||||||
for i, line := range lines {
|
|
||||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
// rewriteStreamEvent processes a single JSON event in the SSE stream.
|
||||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
// It rewrites model names and ensures signature fields exist.
|
||||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
// NOTE: streaming mode does NOT suppress thinking blocks - they are
|
||||||
// Rewrite JSON in the data line
|
// passed through with signature injection to avoid breaking SSE index
|
||||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
// alignment and TUI rendering.
|
||||||
lines[i] = append([]byte("data: "), rewritten...)
|
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||||
|
// Inject empty signature where needed
|
||||||
|
data = ensureAmpSignature(data)
|
||||||
|
|
||||||
|
// Rewrite model name
|
||||||
|
if rw.originalModel != "" {
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return bytes.Join(lines, []byte("\n"))
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
||||||
|
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
|
||||||
|
// array before forwarding to the upstream API.
|
||||||
|
// This prevents 400 errors from the API which requires valid signatures on thinking
|
||||||
|
// blocks and does not accept a signature field on tool_use blocks.
|
||||||
|
func SanitizeAmpRequestBody(body []byte) []byte {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for msgIdx, msg := range messages.Array() {
|
||||||
|
if msg.Get("role").String() != "assistant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepBlocks []interface{}
|
||||||
|
contentModified := false
|
||||||
|
|
||||||
|
for _, block := range content.Array() {
|
||||||
|
blockType := block.Get("type").String()
|
||||||
|
if blockType == "thinking" {
|
||||||
|
sig := block.Get("signature")
|
||||||
|
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
||||||
|
contentModified = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
|
||||||
|
blockRaw := []byte(block.Raw)
|
||||||
|
if blockType == "tool_use" && block.Get("signature").Exists() {
|
||||||
|
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
|
||||||
|
contentModified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
|
||||||
|
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentModified {
|
||||||
|
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
||||||
|
var err error
|
||||||
|
if len(keepBlocks) == 0 {
|
||||||
|
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
|
||||||
|
} else {
|
||||||
|
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if modified {
|
||||||
|
log.Debugf("Amp RequestSanitizer: sanitized request body")
|
||||||
|
}
|
||||||
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -100,6 +101,80 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{}
|
||||||
|
|
||||||
|
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
|
||||||
|
result := rw.rewriteStreamChunk(chunk)
|
||||||
|
|
||||||
|
// Streaming mode preserves thinking blocks (does NOT suppress them)
|
||||||
|
// to avoid breaking SSE index alignment and TUI rendering
|
||||||
|
if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
|
||||||
|
t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
|
||||||
|
t.Fatalf("expected thinking_delta to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
|
||||||
|
t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
// Signature should be injected into both thinking and tool_use blocks
|
||||||
|
if count := strings.Count(string(result), `"signature":""`); count != 2 {
|
||||||
|
t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-whitespace")) {
|
||||||
|
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte("drop-number")) {
|
||||||
|
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte("keep-valid")) {
|
||||||
|
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte("keep-text")) {
|
||||||
|
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte(`"signature":""`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"valid-sig"`)) {
|
||||||
|
t.Fatalf("expected thinking signature to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-me")) {
|
||||||
|
t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte(`"signature"`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func contains(data, substr []byte) bool {
|
func contains(data, substr []byte) bool {
|
||||||
for i := 0; i <= len(data)-len(substr); i++ {
|
for i := 0; i <= len(data)-len(substr); i++ {
|
||||||
if string(data[i:i+len(substr)]) == string(substr) {
|
if string(data[i:i+len(substr)]) == string(substr) {
|
||||||
|
|||||||
@@ -323,6 +323,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
// setupRoutes configures the API routes for the server.
|
// setupRoutes configures the API routes for the server.
|
||||||
// It defines the endpoints and associates them with their respective handlers.
|
// It defines the endpoints and associates them with their respective handlers.
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
|
s.engine.GET("/healthz", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
||||||
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
||||||
@@ -569,6 +573,8 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||||
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||||
|
|
||||||
|
mgmt.GET("/copilot-quota", s.mgmt.GetCopilotQuota)
|
||||||
|
|
||||||
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
||||||
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
||||||
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server {
|
|||||||
return NewServer(cfg, authManager, accessManager, configPath)
|
return NewServer(cfg, authManager, accessManager, configPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHealthz(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String())
|
||||||
|
}
|
||||||
|
if resp.Status != "ok" {
|
||||||
|
t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -172,6 +195,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
true,
|
true,
|
||||||
"issue-1711",
|
"issue-1711",
|
||||||
time.Now(),
|
time.Now(),
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
|||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"org:create_api_key user:profile user:inference"},
|
"scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
"code_challenge_method": {"S256"},
|
"code_challenge_method": {"S256"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
|
|||||||
@@ -235,6 +235,74 @@ type CopilotModelEntry struct {
|
|||||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopilotModelLimits holds the token limits returned by the Copilot /models API
|
||||||
|
// under capabilities.limits. These limits vary by account type (individual vs
|
||||||
|
// business) and are the authoritative source for enforcing prompt size.
|
||||||
|
type CopilotModelLimits struct {
|
||||||
|
// MaxContextWindowTokens is the total context window (prompt + output).
|
||||||
|
MaxContextWindowTokens int
|
||||||
|
// MaxPromptTokens is the hard limit on input/prompt tokens.
|
||||||
|
// Exceeding this triggers a 400 error from the Copilot API.
|
||||||
|
MaxPromptTokens int
|
||||||
|
// MaxOutputTokens is the maximum number of output/completion tokens.
|
||||||
|
MaxOutputTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limits extracts the token limits from the model's capabilities map.
|
||||||
|
// Returns nil if no limits are available or the structure is unexpected.
|
||||||
|
//
|
||||||
|
// Expected Copilot API shape:
|
||||||
|
//
|
||||||
|
// "capabilities": {
|
||||||
|
// "limits": {
|
||||||
|
// "max_context_window_tokens": 200000,
|
||||||
|
// "max_prompt_tokens": 168000,
|
||||||
|
// "max_output_tokens": 32000
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (e *CopilotModelEntry) Limits() *CopilotModelLimits {
|
||||||
|
if e.Capabilities == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limitsRaw, ok := e.Capabilities["limits"]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limitsMap, ok := limitsRaw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &CopilotModelLimits{
|
||||||
|
MaxContextWindowTokens: anyToInt(limitsMap["max_context_window_tokens"]),
|
||||||
|
MaxPromptTokens: anyToInt(limitsMap["max_prompt_tokens"]),
|
||||||
|
MaxOutputTokens: anyToInt(limitsMap["max_output_tokens"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only return if at least one field is populated.
|
||||||
|
if result.MaxContextWindowTokens == 0 && result.MaxPromptTokens == 0 && result.MaxOutputTokens == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// anyToInt converts a JSON-decoded numeric value to int.
|
||||||
|
// Go's encoding/json decodes numbers into float64 when the target is any/interface{}.
|
||||||
|
func anyToInt(v any) int {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
case float32:
|
||||||
|
return int(n)
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||||
type CopilotModelsResponse struct {
|
type CopilotModelsResponse struct {
|
||||||
Data []CopilotModelEntry `json:"data"`
|
Data []CopilotModelEntry `json:"data"`
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ type VertexCredentialStorage struct {
|
|||||||
|
|
||||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Prefix optionally namespaces models for this credential (e.g., "teamA").
|
||||||
|
// This results in model names like "teamA/gemini-2.0-flash".
|
||||||
|
Prefix string `json:"prefix,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
||||||
// it as a "vertex" provider credential. The file content is embedded in the auth
|
// it as a "vertex" provider credential. The file content is embedded in the auth
|
||||||
// file to allow portable deployment across stores.
|
// file to allow portable deployment across stores.
|
||||||
func DoVertexImport(cfg *config.Config, keyPath string) {
|
func DoVertexImport(cfg *config.Config, keyPath string, prefix string) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = &config.Config{}
|
cfg = &config.Config{}
|
||||||
}
|
}
|
||||||
@@ -62,13 +62,28 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
// Default location if not provided by user. Can be edited in the saved file later.
|
// Default location if not provided by user. Can be edited in the saved file later.
|
||||||
location := "us-central1"
|
location := "us-central1"
|
||||||
|
|
||||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
|
// Normalize and validate prefix: must be a single segment (no "/" allowed).
|
||||||
|
prefix = strings.TrimSpace(prefix)
|
||||||
|
prefix = strings.Trim(prefix, "/")
|
||||||
|
if prefix != "" && strings.Contains(prefix, "/") {
|
||||||
|
log.Errorf("vertex-import: prefix must be a single segment (no '/' allowed): %q", prefix)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include prefix in filename so importing the same project with different
|
||||||
|
// prefixes creates separate credential files instead of overwriting.
|
||||||
|
baseName := sanitizeFilePart(projectID)
|
||||||
|
if prefix != "" {
|
||||||
|
baseName = sanitizeFilePart(prefix) + "-" + baseName
|
||||||
|
}
|
||||||
|
fileName := fmt.Sprintf("vertex-%s.json", baseName)
|
||||||
// Build auth record
|
// Build auth record
|
||||||
storage := &vertex.VertexCredentialStorage{
|
storage := &vertex.VertexCredentialStorage{
|
||||||
ServiceAccount: sa,
|
ServiceAccount: sa,
|
||||||
ProjectID: projectID,
|
ProjectID: projectID,
|
||||||
Email: email,
|
Email: email,
|
||||||
Location: location,
|
Location: location,
|
||||||
|
Prefix: prefix,
|
||||||
}
|
}
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"service_account": sa,
|
"service_account": sa,
|
||||||
@@ -76,6 +91,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
"email": email,
|
"email": email,
|
||||||
"location": location,
|
"location": location,
|
||||||
"type": "vertex",
|
"type": "vertex",
|
||||||
|
"prefix": prefix,
|
||||||
"label": labelForVertex(projectID, email),
|
"label": labelForVertex(projectID, email),
|
||||||
}
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
|
|||||||
@@ -211,6 +211,10 @@ type QuotaExceeded struct {
|
|||||||
|
|
||||||
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||||
|
|
||||||
|
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
|
||||||
|
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
|
||||||
|
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingConfig configures how credentials are selected for requests.
|
// RoutingConfig configures how credentials are selected for requests.
|
||||||
@@ -257,8 +261,8 @@ type AmpCode struct {
|
|||||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
// is used for the upstream Amp request.
|
||||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||||
|
|
||||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||||
@@ -380,6 +384,11 @@ type ClaudeKey struct {
|
|||||||
|
|
||||||
// Cloak configures request cloaking for non-Claude-Code clients.
|
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||||
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||||
|
|
||||||
|
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
|
||||||
|
// Claude /v1/messages requests. It is disabled by default so upstream seed
|
||||||
|
// changes do not alter the proxy's legacy behavior.
|
||||||
|
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||||
@@ -972,6 +981,7 @@ func (cfg *Config) SanitizeKiroKeys() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
||||||
|
// It uses API key + base URL as the uniqueness key.
|
||||||
func (cfg *Config) SanitizeGeminiKeys() {
|
func (cfg *Config) SanitizeGeminiKeys() {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return
|
return
|
||||||
@@ -990,10 +1000,11 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
|||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
if _, exists := seen[entry.APIKey]; exists {
|
uniqueKey := entry.APIKey + "|" + entry.BaseURL
|
||||||
|
if _, exists := seen[uniqueKey]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seen[entry.APIKey] = struct{}{}
|
seen[uniqueKey] = struct{}{}
|
||||||
out = append(out, entry)
|
out = append(out, entry)
|
||||||
}
|
}
|
||||||
cfg.GeminiKey = out
|
cfg.GeminiKey = out
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ type SDKConfig struct {
|
|||||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||||
|
|
||||||
|
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
||||||
|
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
||||||
|
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
|
||||||
|
|
||||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||||
// credentials as well.
|
// credentials as well.
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/flate"
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
@@ -41,15 +42,17 @@ type RequestLogger interface {
|
|||||||
// - statusCode: The response status code
|
// - statusCode: The response status code
|
||||||
// - responseHeaders: The response headers
|
// - responseHeaders: The response headers
|
||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
|
// - websocketTimeline: Optional downstream websocket event timeline
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
|
// - apiWebsocketTimeline: Optional upstream websocket event timeline
|
||||||
// - requestID: Optional request ID for log file naming
|
// - requestID: Optional request ID for log file naming
|
||||||
// - requestTimestamp: When the request was received
|
// - requestTimestamp: When the request was received
|
||||||
// - apiResponseTimestamp: When the API response was received
|
// - apiResponseTimestamp: When the API response was received
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||||
|
|
||||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||||
//
|
//
|
||||||
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
|
|||||||
// - error: An error if writing fails, nil otherwise
|
// - error: An error if writing fails, nil otherwise
|
||||||
WriteAPIResponse(apiResponse []byte) error
|
WriteAPIResponse(apiResponse []byte) error
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
|
||||||
|
// This should be called when upstream communication happened over websocket.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if writing fails, nil otherwise
|
||||||
|
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
|
||||||
|
|
||||||
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
if !l.enabled && !force {
|
if !l.enabled && !force {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
|||||||
requestHeaders,
|
requestHeaders,
|
||||||
body,
|
body,
|
||||||
requestBodyPath,
|
requestBodyPath,
|
||||||
|
websocketTimeline,
|
||||||
apiRequest,
|
apiRequest,
|
||||||
apiResponse,
|
apiResponse,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
statusCode,
|
statusCode,
|
||||||
responseHeaders,
|
responseHeaders,
|
||||||
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
requestHeaders map[string][]string,
|
requestHeaders map[string][]string,
|
||||||
requestBody []byte,
|
requestBody []byte,
|
||||||
requestBodyPath string,
|
requestBodyPath string,
|
||||||
|
websocketTimeline []byte,
|
||||||
apiRequest []byte,
|
apiRequest []byte,
|
||||||
apiResponse []byte,
|
apiResponse []byte,
|
||||||
|
apiWebsocketTimeline []byte,
|
||||||
apiResponseErrors []*interfaces.ErrorMessage,
|
apiResponseErrors []*interfaces.ErrorMessage,
|
||||||
statusCode int,
|
statusCode int,
|
||||||
responseHeaders map[string][]string,
|
responseHeaders map[string][]string,
|
||||||
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
if requestTimestamp.IsZero() {
|
if requestTimestamp.IsZero() {
|
||||||
requestTimestamp = time.Now()
|
requestTimestamp = time.Now()
|
||||||
}
|
}
|
||||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||||
|
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
|
||||||
|
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||||
|
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||||
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if isWebsocketTranscript {
|
||||||
|
// Intentionally omit the generic downstream HTTP response section for websocket
|
||||||
|
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
|
||||||
|
// and appending a one-off upgrade response snapshot would dilute that transcript.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
|
|||||||
body []byte,
|
body []byte,
|
||||||
bodyPath string,
|
bodyPath string,
|
||||||
timestamp time.Time,
|
timestamp time.Time,
|
||||||
|
downstreamTransport string,
|
||||||
|
upstreamTransport string,
|
||||||
|
includeBody bool,
|
||||||
) error {
|
) error {
|
||||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
|
|||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(downstreamTransport) != "" {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upstreamTransport) != "" {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !includeBody {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyTrailingNewlines := 1
|
||||||
if bodyPath != "" {
|
if bodyPath != "" {
|
||||||
bodyFile, errOpen := os.Open(bodyPath)
|
bodyFile, errOpen := os.Open(bodyPath)
|
||||||
if errOpen != nil {
|
if errOpen != nil {
|
||||||
return errOpen
|
return errOpen
|
||||||
}
|
}
|
||||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
tracker := &trailingNewlineTrackingWriter{writer: w}
|
||||||
|
written, errCopy := io.Copy(tracker, bodyFile)
|
||||||
|
if errCopy != nil {
|
||||||
_ = bodyFile.Close()
|
_ = bodyFile.Close()
|
||||||
return errCopy
|
return errCopy
|
||||||
}
|
}
|
||||||
|
if written > 0 {
|
||||||
|
bodyTrailingNewlines = tracker.trailingNewlines
|
||||||
|
}
|
||||||
if errClose := bodyFile.Close(); errClose != nil {
|
if errClose := bodyFile.Close(); errClose != nil {
|
||||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||||
}
|
}
|
||||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
|
} else if len(body) > 0 {
|
||||||
|
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
|
||||||
}
|
}
|
||||||
|
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
|
||||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func countTrailingNewlinesBytes(payload []byte) int {
|
||||||
|
count := 0
|
||||||
|
for i := len(payload) - 1; i >= 0; i-- {
|
||||||
|
if payload[i] != '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
|
||||||
|
missingNewlines := 3 - trailingNewlines
|
||||||
|
if missingNewlines <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
type trailingNewlineTrackingWriter struct {
|
||||||
|
writer io.Writer
|
||||||
|
trailingNewlines int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
|
||||||
|
written, errWrite := t.writer.Write(payload)
|
||||||
|
if written > 0 {
|
||||||
|
writtenPayload := payload[:written]
|
||||||
|
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
|
||||||
|
if trailingNewlines == len(writtenPayload) {
|
||||||
|
t.trailingNewlines += trailingNewlines
|
||||||
|
} else {
|
||||||
|
t.trailingNewlines = trailingNewlines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return written, errWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSectionPayload(payload []byte) bool {
|
||||||
|
return len(bytes.TrimSpace(payload)) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
|
||||||
|
if hasSectionPayload(websocketTimeline) {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
for key, values := range headers {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
|
||||||
|
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
|
||||||
|
hasWS := hasSectionPayload(apiWebsocketTimeline)
|
||||||
|
switch {
|
||||||
|
case hasHTTP && hasWS:
|
||||||
|
return "websocket+http"
|
||||||
|
case hasWS:
|
||||||
|
return "websocket"
|
||||||
|
case hasHTTP:
|
||||||
|
return "http"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
|
||||||
return errWrite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
|
||||||
return errWrite
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
|
|||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
trailingNewlines := 1
|
||||||
if apiResponseErrors[i].Error != nil {
|
if apiResponseErrors[i].Error != nil {
|
||||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
errText := apiResponseErrors[i].Error.Error()
|
||||||
|
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if errText != "" {
|
||||||
|
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
var bufferedReader *bufio.Reader
|
||||||
return errWrite
|
if responseReader != nil {
|
||||||
|
bufferedReader = bufio.NewReader(responseReader)
|
||||||
|
}
|
||||||
|
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
|
||||||
|
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseReader != nil {
|
if bufferedReader != nil {
|
||||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
|
||||||
return errCopy
|
return errCopy
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
|
||||||
|
if reader == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// formatLogContent creates the complete log content for non-streaming requests.
|
// formatLogContent creates the complete log content for non-streaming requests.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
// - method: The HTTP method
|
// - method: The HTTP method
|
||||||
// - headers: The request headers
|
// - headers: The request headers
|
||||||
// - body: The request body
|
// - body: The request body
|
||||||
|
// - websocketTimeline: The downstream websocket event timeline
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: The formatted log content
|
// - string: The formatted log content
|
||||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||||
|
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
|
||||||
|
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||||
|
|
||||||
// Request info
|
// Request info
|
||||||
content.WriteString(l.formatRequestInfo(url, method, headers, body))
|
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
|
||||||
|
|
||||||
|
if len(websocketTimeline) > 0 {
|
||||||
|
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
|
||||||
|
content.Write(websocketTimeline)
|
||||||
|
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
|
||||||
|
content.Write(websocketTimeline)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(apiWebsocketTimeline) > 0 {
|
||||||
|
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
|
||||||
|
content.Write(apiWebsocketTimeline)
|
||||||
|
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
|
||||||
|
content.Write(apiWebsocketTimeline)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
if len(apiRequest) > 0 {
|
if len(apiRequest) > 0 {
|
||||||
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
||||||
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
|
|||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isWebsocketTranscript {
|
||||||
|
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
|
||||||
|
// timeline sections instead of a generic downstream HTTP response block.
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
// Response section
|
// Response section
|
||||||
content.WriteString("=== RESPONSE ===\n")
|
content.WriteString("=== RESPONSE ===\n")
|
||||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||||
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: The formatted request information
|
// - string: The formatted request information
|
||||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
|
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
|
||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
|
||||||
content.WriteString("=== REQUEST INFO ===\n")
|
content.WriteString("=== REQUEST INFO ===\n")
|
||||||
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
||||||
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||||
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||||
|
if strings.TrimSpace(downstreamTransport) != "" {
|
||||||
|
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upstreamTransport) != "" {
|
||||||
|
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
|
||||||
|
}
|
||||||
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
|
|
||||||
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
|||||||
}
|
}
|
||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
if !includeBody {
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
content.WriteString("=== REQUEST BODY ===\n")
|
content.WriteString("=== REQUEST BODY ===\n")
|
||||||
content.Write(body)
|
content.Write(body)
|
||||||
content.WriteString("\n\n")
|
content.WriteString("\n\n")
|
||||||
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
|
|||||||
// apiResponse stores the upstream API response data.
|
// apiResponse stores the upstream API response data.
|
||||||
apiResponse []byte
|
apiResponse []byte
|
||||||
|
|
||||||
|
// apiWebsocketTimeline stores the upstream websocket event timeline.
|
||||||
|
apiWebsocketTimeline []byte
|
||||||
|
|
||||||
// apiResponseTimestamp captures when the API response was received.
|
// apiResponseTimestamp captures when the API response was received.
|
||||||
apiResponseTimestamp time.Time
|
apiResponseTimestamp time.Time
|
||||||
}
|
}
|
||||||
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil (buffering cannot fail)
|
||||||
|
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||||
|
if len(apiWebsocketTimeline) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||||
if !timestamp.IsZero() {
|
if !timestamp.IsZero() {
|
||||||
w.apiResponseTimestamp = timestamp
|
w.apiResponseTimestamp = timestamp
|
||||||
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
|||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
// It writes all buffered data to the file in the correct order:
|
// It writes all buffered data to the file in the correct order:
|
||||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if closing fails, nil otherwise
|
// - error: An error if closing fails, nil otherwise
|
||||||
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||||
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil
|
||||||
|
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||||
|
|
||||||
// Close is a no-op implementation that does nothing and always returns nil.
|
// Close is a no-op implementation that does nothing and always returns nil.
|
||||||
|
|||||||
151
internal/misc/antigravity_version.go
Normal file
151
internal/misc/antigravity_version.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
|
||||||
|
package misc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases"
|
||||||
|
antigravityFallbackVersion = "1.21.9"
|
||||||
|
antigravityVersionCacheTTL = 6 * time.Hour
|
||||||
|
antigravityFetchTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type antigravityRelease struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
ExecutionID string `json:"execution_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
cachedAntigravityVersion = antigravityFallbackVersion
|
||||||
|
antigravityVersionMu sync.RWMutex
|
||||||
|
antigravityVersionExpiry time.Time
|
||||||
|
antigravityUpdaterOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version.
|
||||||
|
// This is intentionally decoupled from request execution to avoid blocking executors on version lookups.
|
||||||
|
func StartAntigravityVersionUpdater(ctx context.Context) {
|
||||||
|
antigravityUpdaterOnce.Do(func() {
|
||||||
|
go runAntigravityVersionUpdater(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runAntigravityVersionUpdater(ctx context.Context) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(antigravityVersionCacheTTL / 2)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
log.Infof("periodic antigravity version refresh started (interval=%s)", antigravityVersionCacheTTL/2)
|
||||||
|
|
||||||
|
refreshAntigravityVersion(ctx)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
refreshAntigravityVersion(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshAntigravityVersion(ctx context.Context) {
|
||||||
|
version, errFetch := fetchAntigravityLatestVersion(ctx)
|
||||||
|
|
||||||
|
antigravityVersionMu.Lock()
|
||||||
|
defer antigravityVersionMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if errFetch == nil {
|
||||||
|
cachedAntigravityVersion = version
|
||||||
|
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||||
|
log.WithField("version", version).Info("fetched latest antigravity version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry) {
|
||||||
|
cachedAntigravityVersion = antigravityFallbackVersion
|
||||||
|
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||||
|
log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater.
|
||||||
|
// It falls back to antigravityFallbackVersion if the cache is empty or stale.
|
||||||
|
func AntigravityLatestVersion() string {
|
||||||
|
antigravityVersionMu.RLock()
|
||||||
|
if cachedAntigravityVersion != "" && time.Now().Before(antigravityVersionExpiry) {
|
||||||
|
v := cachedAntigravityVersion
|
||||||
|
antigravityVersionMu.RUnlock()
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
antigravityVersionMu.RUnlock()
|
||||||
|
|
||||||
|
return antigravityFallbackVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityUserAgent returns the User-Agent string for antigravity requests
|
||||||
|
// using the latest version fetched from the releases API.
|
||||||
|
func AntigravityUserAgent() string {
|
||||||
|
return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion())
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchAntigravityLatestVersion(ctx context.Context) (string, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: antigravityFetchTimeout}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil)
|
||||||
|
if errReq != nil {
|
||||||
|
return "", fmt.Errorf("build antigravity releases request: %w", errReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, errDo := client.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("fetch antigravity releases: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.WithError(errClose).Warn("antigravity releases response body close error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var releases []antigravityRelease
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode antigravity releases response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(releases) == 0 {
|
||||||
|
return "", errors.New("antigravity releases API returned empty list")
|
||||||
|
}
|
||||||
|
|
||||||
|
version := releases[0].Version
|
||||||
|
if version == "" {
|
||||||
|
return "", errors.New("antigravity releases API returned empty version")
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
@@ -93,6 +93,54 @@ func GetAntigravityModels() []*ModelInfo {
|
|||||||
func GetCodeBuddyModels() []*ModelInfo {
|
func GetCodeBuddyModels() []*ModelInfo {
|
||||||
now := int64(1748044800) // 2025-05-24
|
now := int64(1748044800) // 2025-05-24
|
||||||
return []*ModelInfo{
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "auto",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Auto",
|
||||||
|
Description: "Automatic model selection via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5v-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5v Turbo",
|
||||||
|
Description: "GLM-5v Turbo via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.1",
|
||||||
|
Description: "GLM-5.1 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.0-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.0 Turbo",
|
||||||
|
Description: "GLM-5.0 Turbo via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "glm-5.0",
|
ID: "glm-5.0",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -101,7 +149,7 @@ func GetCodeBuddyModels() []*ModelInfo {
|
|||||||
Type: "codebuddy",
|
Type: "codebuddy",
|
||||||
DisplayName: "GLM-5.0",
|
DisplayName: "GLM-5.0",
|
||||||
Description: "GLM-5.0 via CodeBuddy",
|
Description: "GLM-5.0 via CodeBuddy",
|
||||||
ContextLength: 128000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 32768,
|
MaxCompletionTokens: 32768,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
@@ -113,18 +161,18 @@ func GetCodeBuddyModels() []*ModelInfo {
|
|||||||
Type: "codebuddy",
|
Type: "codebuddy",
|
||||||
DisplayName: "GLM-4.7",
|
DisplayName: "GLM-4.7",
|
||||||
Description: "GLM-4.7 via CodeBuddy",
|
Description: "GLM-4.7 via CodeBuddy",
|
||||||
ContextLength: 128000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 32768,
|
MaxCompletionTokens: 32768,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "minimax-m2.5",
|
ID: "minimax-m2.7",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: now,
|
Created: now,
|
||||||
OwnedBy: "tencent",
|
OwnedBy: "tencent",
|
||||||
Type: "codebuddy",
|
Type: "codebuddy",
|
||||||
DisplayName: "MiniMax M2.5",
|
DisplayName: "MiniMax M2.7",
|
||||||
Description: "MiniMax M2.5 via CodeBuddy",
|
Description: "MiniMax M2.7 via CodeBuddy",
|
||||||
ContextLength: 200000,
|
ContextLength: 200000,
|
||||||
MaxCompletionTokens: 32768,
|
MaxCompletionTokens: 32768,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
@@ -137,10 +185,23 @@ func GetCodeBuddyModels() []*ModelInfo {
|
|||||||
Type: "codebuddy",
|
Type: "codebuddy",
|
||||||
DisplayName: "Kimi K2.5",
|
DisplayName: "Kimi K2.5",
|
||||||
Description: "Kimi K2.5 via CodeBuddy",
|
Description: "Kimi K2.5 via CodeBuddy",
|
||||||
ContextLength: 128000,
|
ContextLength: 256000,
|
||||||
MaxCompletionTokens: 32768,
|
MaxCompletionTokens: 32768,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2-thinking",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Kimi K2 Thinking",
|
||||||
|
Description: "Kimi K2 Thinking via CodeBuddy",
|
||||||
|
ContextLength: 256000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "deepseek-v3-2-volc",
|
ID: "deepseek-v3-2-volc",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -148,24 +209,11 @@ func GetCodeBuddyModels() []*ModelInfo {
|
|||||||
OwnedBy: "tencent",
|
OwnedBy: "tencent",
|
||||||
Type: "codebuddy",
|
Type: "codebuddy",
|
||||||
DisplayName: "DeepSeek V3.2 (Volc)",
|
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||||
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
|
Description: "DeepSeek V3.2 via CodeBuddy",
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 32768,
|
MaxCompletionTokens: 32768,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
ID: "hunyuan-2.0-thinking",
|
|
||||||
Object: "model",
|
|
||||||
Created: now,
|
|
||||||
OwnedBy: "tencent",
|
|
||||||
Type: "codebuddy",
|
|
||||||
DisplayName: "Hunyuan 2.0 Thinking",
|
|
||||||
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
|
|
||||||
ContextLength: 128000,
|
|
||||||
MaxCompletionTokens: 32768,
|
|
||||||
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,6 +335,13 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultCopilotClaudeContextLength is the conservative prompt token limit for
|
||||||
|
// Claude models accessed via the GitHub Copilot API. Individual accounts are
|
||||||
|
// capped at 128K; business accounts at 168K. When the dynamic /models API fetch
|
||||||
|
// succeeds, the real per-account limit overrides this value. This constant is
|
||||||
|
// only used as a safe fallback.
|
||||||
|
const defaultCopilotClaudeContextLength = 128000
|
||||||
|
|
||||||
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
func GetGitHubCopilotModels() []*ModelInfo {
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
@@ -477,6 +532,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
SupportedEndpoints: []string{"/responses"},
|
SupportedEndpoints: []string{"/responses"},
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.4 mini",
|
||||||
|
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-haiku-4.5",
|
ID: "claude-haiku-4.5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -485,7 +553,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Haiku 4.5",
|
DisplayName: "Claude Haiku 4.5",
|
||||||
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
@@ -497,7 +565,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.1",
|
DisplayName: "Claude Opus 4.1",
|
||||||
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 32000,
|
MaxCompletionTokens: 32000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
@@ -509,9 +577,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.5",
|
DisplayName: "Claude Opus 4.5",
|
||||||
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.6",
|
ID: "claude-opus-4.6",
|
||||||
@@ -521,9 +590,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.6",
|
DisplayName: "Claude Opus 4.6",
|
||||||
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4",
|
ID: "claude-sonnet-4",
|
||||||
@@ -533,9 +603,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4",
|
DisplayName: "Claude Sonnet 4",
|
||||||
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4.5",
|
ID: "claude-sonnet-4.5",
|
||||||
@@ -545,9 +616,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4.5",
|
DisplayName: "Claude Sonnet 4.5",
|
||||||
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4.6",
|
ID: "claude-sonnet-4.6",
|
||||||
@@ -557,9 +629,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4.6",
|
DisplayName: "Claude Sonnet 4.6",
|
||||||
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-2.5-pro",
|
ID: "gemini-2.5-pro",
|
||||||
@@ -571,6 +644,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-pro-preview",
|
ID: "gemini-3-pro-preview",
|
||||||
@@ -582,6 +656,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3.1-pro-preview",
|
ID: "gemini-3.1-pro-preview",
|
||||||
@@ -591,8 +666,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Gemini 3.1 Pro (Preview)",
|
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||||
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 173000,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
@@ -602,8 +678,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Gemini 3 Flash (Preview)",
|
DisplayName: "Gemini 3 Flash (Preview)",
|
||||||
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 173000,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "grok-code-fast-1",
|
ID: "grok-code-fast-1",
|
||||||
|
|||||||
29
internal/registry/model_definitions_test.go
Normal file
29
internal/registry/model_definitions_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGitHubCopilotGeminiModelsAreChatOnly(t *testing.T) {
|
||||||
|
models := GetGitHubCopilotModels()
|
||||||
|
required := map[string]bool{
|
||||||
|
"gemini-2.5-pro": false,
|
||||||
|
"gemini-3-pro-preview": false,
|
||||||
|
"gemini-3.1-pro-preview": false,
|
||||||
|
"gemini-3-flash-preview": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if _, ok := required[model.ID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
required[model.ID] = true
|
||||||
|
if len(model.SupportedEndpoints) != 1 || model.SupportedEndpoints[0] != "/chat/completions" {
|
||||||
|
t.Fatalf("model %q supported endpoints = %v, want [/chat/completions]", model.ID, model.SupportedEndpoints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelID, found := range required {
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected GitHub Copilot model %q in definitions", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1177,6 +1177,16 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Include context limits so Claude Code can manage conversation
|
||||||
|
// context correctly, especially for Copilot-proxied models whose
|
||||||
|
// real prompt limit (128K-168K) is much lower than the 1M window
|
||||||
|
// that Claude Code may assume for Opus 4.6 with 1M context enabled.
|
||||||
|
if model.ContextLength > 0 {
|
||||||
|
result["context_length"] = model.ContextLength
|
||||||
|
}
|
||||||
|
if model.MaxCompletionTokens > 0 {
|
||||||
|
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
case "gemini":
|
case "gemini":
|
||||||
|
|||||||
@@ -280,6 +280,7 @@
|
|||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"low",
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -554,6 +555,7 @@
|
|||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"low",
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -610,6 +612,8 @@
|
|||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"minimal",
|
"minimal",
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -838,6 +842,7 @@
|
|||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"low",
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -896,6 +901,8 @@
|
|||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"minimal",
|
"minimal",
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -1070,6 +1077,8 @@
|
|||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"minimal",
|
"minimal",
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -1371,6 +1380,75 @@
|
|||||||
"xhigh"
|
"xhigh"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5.3-codex",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1770307200,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"type": "openai",
|
||||||
|
"display_name": "GPT 5.3 Codex",
|
||||||
|
"version": "gpt-5.3",
|
||||||
|
"description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
|
"context_length": 400000,
|
||||||
|
"max_completion_tokens": 128000,
|
||||||
|
"supported_parameters": [
|
||||||
|
"tools"
|
||||||
|
],
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"xhigh"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5.4",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1772668800,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"type": "openai",
|
||||||
|
"display_name": "GPT 5.4",
|
||||||
|
"version": "gpt-5.4",
|
||||||
|
"description": "Stable version of GPT 5.4",
|
||||||
|
"context_length": 1050000,
|
||||||
|
"max_completion_tokens": 128000,
|
||||||
|
"supported_parameters": [
|
||||||
|
"tools"
|
||||||
|
],
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"xhigh"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5.4-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773705600,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"type": "openai",
|
||||||
|
"display_name": "GPT 5.4 Mini",
|
||||||
|
"version": "gpt-5.4-mini",
|
||||||
|
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
|
||||||
|
"context_length": 400000,
|
||||||
|
"max_completion_tokens": 128000,
|
||||||
|
"supported_parameters": [
|
||||||
|
"tools"
|
||||||
|
],
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"xhigh"
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"codex-team": [
|
"codex-team": [
|
||||||
@@ -1623,6 +1701,29 @@
|
|||||||
"xhigh"
|
"xhigh"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5.4-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773705600,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"type": "openai",
|
||||||
|
"display_name": "GPT 5.4 Mini",
|
||||||
|
"version": "gpt-5.4-mini",
|
||||||
|
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
|
||||||
|
"context_length": 400000,
|
||||||
|
"max_completion_tokens": 128000,
|
||||||
|
"supported_parameters": [
|
||||||
|
"tools"
|
||||||
|
],
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"xhigh"
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"codex-plus": [
|
"codex-plus": [
|
||||||
@@ -1898,6 +1999,29 @@
|
|||||||
"xhigh"
|
"xhigh"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5.4-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773705600,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"type": "openai",
|
||||||
|
"display_name": "GPT 5.4 Mini",
|
||||||
|
"version": "gpt-5.4-mini",
|
||||||
|
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
|
||||||
|
"context_length": 400000,
|
||||||
|
"max_completion_tokens": 128000,
|
||||||
|
"supported_parameters": [
|
||||||
|
"tools"
|
||||||
|
],
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"xhigh"
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"codex-pro": [
|
"codex-pro": [
|
||||||
@@ -2173,55 +2297,40 @@
|
|||||||
"xhigh"
|
"xhigh"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5.4-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773705600,
|
||||||
|
"owned_by": "openai",
|
||||||
|
"type": "openai",
|
||||||
|
"display_name": "GPT 5.4 Mini",
|
||||||
|
"version": "gpt-5.4-mini",
|
||||||
|
"description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.",
|
||||||
|
"context_length": 400000,
|
||||||
|
"max_completion_tokens": 128000,
|
||||||
|
"supported_parameters": [
|
||||||
|
"tools"
|
||||||
|
],
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"xhigh"
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"qwen": [
|
"qwen": [
|
||||||
{
|
|
||||||
"id": "qwen3-coder-plus",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1753228800,
|
|
||||||
"owned_by": "qwen",
|
|
||||||
"type": "qwen",
|
|
||||||
"display_name": "Qwen3 Coder Plus",
|
|
||||||
"version": "3.0",
|
|
||||||
"description": "Advanced code generation and understanding model",
|
|
||||||
"context_length": 32768,
|
|
||||||
"max_completion_tokens": 8192,
|
|
||||||
"supported_parameters": [
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"max_tokens",
|
|
||||||
"stream",
|
|
||||||
"stop"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "qwen3-coder-flash",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1753228800,
|
|
||||||
"owned_by": "qwen",
|
|
||||||
"type": "qwen",
|
|
||||||
"display_name": "Qwen3 Coder Flash",
|
|
||||||
"version": "3.0",
|
|
||||||
"description": "Fast code generation model",
|
|
||||||
"context_length": 8192,
|
|
||||||
"max_completion_tokens": 2048,
|
|
||||||
"supported_parameters": [
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"max_tokens",
|
|
||||||
"stream",
|
|
||||||
"stop"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": "coder-model",
|
"id": "coder-model",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": 1771171200,
|
"created": 1771171200,
|
||||||
"owned_by": "qwen",
|
"owned_by": "qwen",
|
||||||
"type": "qwen",
|
"type": "qwen",
|
||||||
"display_name": "Qwen 3.5 Plus",
|
"display_name": "Qwen 3.6 Plus",
|
||||||
"version": "3.5",
|
"version": "3.6",
|
||||||
"description": "efficient hybrid model with leading coding performance",
|
"description": "efficient hybrid model with leading coding performance",
|
||||||
"context_length": 1048576,
|
"context_length": 1048576,
|
||||||
"max_completion_tokens": 65536,
|
"max_completion_tokens": 65536,
|
||||||
@@ -2232,25 +2341,6 @@
|
|||||||
"stream",
|
"stream",
|
||||||
"stop"
|
"stop"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "vision-model",
|
|
||||||
"object": "model",
|
|
||||||
"created": 1758672000,
|
|
||||||
"owned_by": "qwen",
|
|
||||||
"type": "qwen",
|
|
||||||
"display_name": "Qwen3 Vision Model",
|
|
||||||
"version": "3.0",
|
|
||||||
"description": "Vision model model",
|
|
||||||
"context_length": 32768,
|
|
||||||
"max_completion_tokens": 2048,
|
|
||||||
"supported_parameters": [
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"max_tokens",
|
|
||||||
"stream",
|
|
||||||
"stop"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"iflow": [
|
"iflow": [
|
||||||
@@ -2639,11 +2729,12 @@
|
|||||||
"context_length": 1048576,
|
"context_length": 1048576,
|
||||||
"max_completion_tokens": 65535,
|
"max_completion_tokens": 65535,
|
||||||
"thinking": {
|
"thinking": {
|
||||||
"min": 128,
|
"min": 1,
|
||||||
"max": 32768,
|
"max": 65535,
|
||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"low",
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -2659,11 +2750,12 @@
|
|||||||
"context_length": 1048576,
|
"context_length": 1048576,
|
||||||
"max_completion_tokens": 65535,
|
"max_completion_tokens": 65535,
|
||||||
"thinking": {
|
"thinking": {
|
||||||
"min": 128,
|
"min": 1,
|
||||||
"max": 32768,
|
"max": 65535,
|
||||||
"dynamic_allowed": true,
|
"dynamic_allowed": true,
|
||||||
"levels": [
|
"levels": [
|
||||||
"low",
|
"low",
|
||||||
|
"medium",
|
||||||
"high"
|
"high"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -46,8 +48,16 @@ func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Man
|
|||||||
// Identifier returns the executor identifier.
|
// Identifier returns the executor identifier.
|
||||||
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
||||||
|
|
||||||
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio).
|
// PrepareRequest prepares the HTTP request for execution.
|
||||||
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
func (e *AIStudioExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,6 +76,9 @@ func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return nil, fmt.Errorf("aistudio executor: missing auth")
|
return nil, fmt.Errorf("aistudio executor: missing auth")
|
||||||
}
|
}
|
||||||
httpReq := req.WithContext(ctx)
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
|
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
|
||||||
return nil, fmt.Errorf("aistudio executor: request URL is empty")
|
return nil, fmt.Errorf("aistudio executor: request URL is empty")
|
||||||
}
|
}
|
||||||
@@ -115,8 +128,8 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
translatedReq, body, err := e.translateRequest(req, opts, false)
|
translatedReq, body, err := e.translateRequest(req, opts, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -130,6 +143,11 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
Body: body.payload,
|
Body: body.payload,
|
||||||
}
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -137,7 +155,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: wsReq.Headers.Clone(),
|
Headers: wsReq.Headers.Clone(),
|
||||||
@@ -151,17 +169,17 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
|
|
||||||
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
|
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||||
if len(wsResp.Body) > 0 {
|
if len(wsResp.Body) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||||
}
|
}
|
||||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||||
}
|
}
|
||||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
reporter.Publish(ctx, helps.ParseGeminiUsage(wsResp.Body))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
|
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
|
||||||
@@ -174,8 +192,8 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
translatedReq, body, err := e.translateRequest(req, opts, true)
|
translatedReq, body, err := e.translateRequest(req, opts, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -189,13 +207,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
Body: body.payload,
|
Body: body.payload,
|
||||||
}
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: wsReq.Headers}, attrs)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: wsReq.Headers.Clone(),
|
Headers: wsReq.Headers.Clone(),
|
||||||
@@ -208,24 +231,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
})
|
})
|
||||||
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
|
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
firstEvent, ok := <-wsStream
|
firstEvent, ok := <-wsStream
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("wsrelay: stream closed before start")
|
err = fmt.Errorf("wsrelay: stream closed before start")
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
|
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
|
||||||
metadataLogged := false
|
metadataLogged := false
|
||||||
if firstEvent.Status > 0 {
|
if firstEvent.Status > 0 {
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
|
||||||
metadataLogged = true
|
metadataLogged = true
|
||||||
}
|
}
|
||||||
var body bytes.Buffer
|
var body bytes.Buffer
|
||||||
if len(firstEvent.Payload) > 0 {
|
if len(firstEvent.Payload) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||||
body.Write(firstEvent.Payload)
|
body.Write(firstEvent.Payload)
|
||||||
}
|
}
|
||||||
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
||||||
@@ -233,18 +256,18 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
}
|
}
|
||||||
for event := range wsStream {
|
for event := range wsStream {
|
||||||
if event.Err != nil {
|
if event.Err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||||
if body.Len() == 0 {
|
if body.Len() == 0 {
|
||||||
body.WriteString(event.Err.Error())
|
body.WriteString(event.Err.Error())
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if !metadataLogged && event.Status > 0 {
|
if !metadataLogged && event.Status > 0 {
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||||
metadataLogged = true
|
metadataLogged = true
|
||||||
}
|
}
|
||||||
if len(event.Payload) > 0 {
|
if len(event.Payload) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||||
body.Write(event.Payload)
|
body.Write(event.Payload)
|
||||||
}
|
}
|
||||||
if event.Type == wsrelay.MessageTypeStreamEnd {
|
if event.Type == wsrelay.MessageTypeStreamEnd {
|
||||||
@@ -260,23 +283,23 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
metadataLogged := false
|
metadataLogged := false
|
||||||
processEvent := func(event wsrelay.StreamEvent) bool {
|
processEvent := func(event wsrelay.StreamEvent) bool {
|
||||||
if event.Err != nil {
|
if event.Err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
case wsrelay.MessageTypeStreamStart:
|
case wsrelay.MessageTypeStreamStart:
|
||||||
if !metadataLogged && event.Status > 0 {
|
if !metadataLogged && event.Status > 0 {
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||||
metadataLogged = true
|
metadataLogged = true
|
||||||
}
|
}
|
||||||
case wsrelay.MessageTypeStreamChunk:
|
case wsrelay.MessageTypeStreamChunk:
|
||||||
if len(event.Payload) > 0 {
|
if len(event.Payload) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||||
filtered := FilterSSEUsageMetadata(event.Payload)
|
filtered := helps.FilterSSEUsageMetadata(event.Payload)
|
||||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
if detail, ok := helps.ParseGeminiStreamUsage(filtered); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
@@ -288,21 +311,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
return false
|
return false
|
||||||
case wsrelay.MessageTypeHTTPResp:
|
case wsrelay.MessageTypeHTTPResp:
|
||||||
if !metadataLogged && event.Status > 0 {
|
if !metadataLogged && event.Status > 0 {
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||||
metadataLogged = true
|
metadataLogged = true
|
||||||
}
|
}
|
||||||
if len(event.Payload) > 0 {
|
if len(event.Payload) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||||
}
|
}
|
||||||
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
|
||||||
return false
|
return false
|
||||||
case wsrelay.MessageTypeError:
|
case wsrelay.MessageTypeError:
|
||||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -345,7 +368,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: wsReq.Headers.Clone(),
|
Headers: wsReq.Headers.Clone(),
|
||||||
@@ -358,12 +381,12 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
|||||||
})
|
})
|
||||||
resp, err := e.relay.NonStream(ctx, authID, wsReq)
|
resp, err := e.relay.NonStream(ctx, authID, wsReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||||
if len(resp.Body) > 0 {
|
if len(resp.Body) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||||
}
|
}
|
||||||
if resp.Status < 200 || resp.Status >= 300 {
|
if resp.Status < 200 || resp.Status >= 300 {
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||||
@@ -404,8 +427,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
|||||||
return nil, translatedPayload{}, err
|
return nil, translatedPayload{}, err
|
||||||
}
|
}
|
||||||
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
@@ -44,15 +46,44 @@ const (
|
|||||||
antigravityGeneratePath = "/v1internal:generateContent"
|
antigravityGeneratePath = "/v1internal:generateContent"
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
antigravityClientSecret = "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"
|
antigravityClientSecret = "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"
|
||||||
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
|
defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent()
|
||||||
antigravityAuthType = "antigravity"
|
antigravityAuthType = "antigravity"
|
||||||
refreshSkew = 3000 * time.Second
|
refreshSkew = 3000 * time.Second
|
||||||
|
antigravityCreditsRetryTTL = 5 * time.Hour
|
||||||
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type antigravity429Category string
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravity429Unknown antigravity429Category = "unknown"
|
||||||
|
antigravity429RateLimited antigravity429Category = "rate_limited"
|
||||||
|
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
randSourceMutex sync.Mutex
|
randSourceMutex sync.Mutex
|
||||||
|
antigravityCreditsExhaustedByAuth sync.Map
|
||||||
|
antigravityPreferCreditsByModel sync.Map
|
||||||
|
antigravityQuotaExhaustedKeywords = []string{
|
||||||
|
"quota_exhausted",
|
||||||
|
"quota exhausted",
|
||||||
|
}
|
||||||
|
antigravityCreditsExhaustedKeywords = []string{
|
||||||
|
"google_one_ai",
|
||||||
|
"insufficient credit",
|
||||||
|
"insufficient credits",
|
||||||
|
"not enough credit",
|
||||||
|
"not enough credits",
|
||||||
|
"credit exhausted",
|
||||||
|
"credits exhausted",
|
||||||
|
"credit balance",
|
||||||
|
"minimumcreditamountforusage",
|
||||||
|
"minimum credit amount for usage",
|
||||||
|
"minimum credit",
|
||||||
|
"resource has been exhausted",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
@@ -113,7 +144,7 @@ func initAntigravityTransport() {
|
|||||||
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
antigravityTransportOnce.Do(initAntigravityTransport)
|
antigravityTransportOnce.Do(initAntigravityTransport)
|
||||||
|
|
||||||
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
client := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||||
// If no transport is set, use the shared HTTP/1.1 transport.
|
// If no transport is set, use the shared HTTP/1.1 transport.
|
||||||
if client.Transport == nil {
|
if client.Transport == nil {
|
||||||
client.Transport = antigravityTransport
|
client.Transport = antigravityTransport
|
||||||
@@ -183,6 +214,259 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
|
|||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func injectEnabledCreditTypes(payload []byte) []byte {
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(payload) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
updated, err := sjson.SetRawBytes(payload, "enabledCreditTypes", []byte(`["GOOGLE_ONE_AI"]`))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyAntigravity429(body []byte) antigravity429Category {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return antigravity429Unknown
|
||||||
|
}
|
||||||
|
lowerBody := strings.ToLower(string(body))
|
||||||
|
for _, keyword := range antigravityQuotaExhaustedKeywords {
|
||||||
|
if strings.Contains(lowerBody, keyword) {
|
||||||
|
return antigravity429QuotaExhausted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
|
||||||
|
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
|
||||||
|
return antigravity429Unknown
|
||||||
|
}
|
||||||
|
details := gjson.GetBytes(body, "error.details")
|
||||||
|
if !details.Exists() || !details.IsArray() {
|
||||||
|
return antigravity429Unknown
|
||||||
|
}
|
||||||
|
for _, detail := range details.Array() {
|
||||||
|
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reason := strings.TrimSpace(detail.Get("reason").String())
|
||||||
|
if strings.EqualFold(reason, "QUOTA_EXHAUSTED") {
|
||||||
|
return antigravity429QuotaExhausted
|
||||||
|
}
|
||||||
|
if strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED") {
|
||||||
|
return antigravity429RateLimited
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return antigravity429Unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityHasQuotaResetDelayOrModelInfo(body []byte) bool {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
details := gjson.GetBytes(body, "error.details")
|
||||||
|
if !details.Exists() || !details.IsArray() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, detail := range details.Array() {
|
||||||
|
if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(detail.Get("metadata.quotaResetDelay").String()) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(detail.Get("metadata.model").String()) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityCreditsRetryEnabled(cfg *config.Config) bool {
|
||||||
|
return cfg != nil && cfg.QuotaExceeded.AntigravityCredits
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) bool {
|
||||||
|
if auth == nil || strings.TrimSpace(auth.ID) == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
value, ok := antigravityCreditsExhaustedByAuth.Load(auth.ID)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
until, ok := value.(time.Time)
|
||||||
|
if !ok || until.IsZero() {
|
||||||
|
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !until.After(now) {
|
||||||
|
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func markAntigravityCreditsExhausted(auth *cliproxyauth.Auth, now time.Time) {
|
||||||
|
if auth == nil || strings.TrimSpace(auth.ID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
antigravityCreditsExhaustedByAuth.Store(auth.ID, now.Add(antigravityCreditsRetryTTL))
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearAntigravityCreditsExhausted(auth *cliproxyauth.Auth) {
|
||||||
|
if auth == nil || strings.TrimSpace(auth.ID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
antigravityCreditsExhaustedByAuth.Delete(auth.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string {
|
||||||
|
if auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
authID := strings.TrimSpace(auth.ID)
|
||||||
|
modelName = strings.TrimSpace(modelName)
|
||||||
|
if authID == "" || modelName == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return authID + "|" + modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool {
|
||||||
|
key := antigravityPreferCreditsKey(auth, modelName)
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
value, ok := antigravityPreferCreditsByModel.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
until, ok := value.(time.Time)
|
||||||
|
if !ok || until.IsZero() {
|
||||||
|
antigravityPreferCreditsByModel.Delete(key)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !until.After(now) {
|
||||||
|
antigravityPreferCreditsByModel.Delete(key)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) {
|
||||||
|
key := antigravityPreferCreditsKey(auth, modelName)
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
until := now.Add(antigravityCreditsRetryTTL)
|
||||||
|
if retryAfter != nil && *retryAfter > 0 {
|
||||||
|
until = now.Add(*retryAfter)
|
||||||
|
}
|
||||||
|
antigravityPreferCreditsByModel.Store(key, until)
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) {
|
||||||
|
key := antigravityPreferCreditsKey(auth, modelName)
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
antigravityPreferCreditsByModel.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool {
|
||||||
|
if reqErr != nil || statusCode == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
lowerBody := strings.ToLower(string(body))
|
||||||
|
for _, keyword := range antigravityCreditsExhaustedKeywords {
|
||||||
|
if strings.Contains(lowerBody, keyword) {
|
||||||
|
if keyword == "resource has been exhausted" &&
|
||||||
|
statusCode == http.StatusTooManyRequests &&
|
||||||
|
classifyAntigravity429(body) == antigravity429Unknown &&
|
||||||
|
!antigravityHasQuotaResetDelayOrModelInfo(body) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAntigravityStatusErr(statusCode int, body []byte) statusErr {
|
||||||
|
err := statusErr{code: statusCode, msg: string(body)}
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil {
|
||||||
|
err.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AntigravityExecutor) attemptCreditsFallback(
|
||||||
|
ctx context.Context,
|
||||||
|
auth *cliproxyauth.Auth,
|
||||||
|
httpClient *http.Client,
|
||||||
|
token string,
|
||||||
|
modelName string,
|
||||||
|
payload []byte,
|
||||||
|
stream bool,
|
||||||
|
alt string,
|
||||||
|
baseURL string,
|
||||||
|
originalBody []byte,
|
||||||
|
) (*http.Response, bool) {
|
||||||
|
if !antigravityCreditsRetryEnabled(e.cfg) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if classifyAntigravity429(originalBody) != antigravity429QuotaExhausted {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if antigravityCreditsExhausted(auth, now) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
creditsPayload := injectEnabledCreditTypes(payload)
|
||||||
|
if len(creditsPayload) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL)
|
||||||
|
if errReq != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errReq)
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
|
retryAfter, _ := parseRetryDelay(originalBody)
|
||||||
|
markAntigravityPreferCredits(auth, modelName, now, retryAfter)
|
||||||
|
clearAntigravityCreditsExhausted(auth)
|
||||||
|
return httpResp, true
|
||||||
|
}
|
||||||
|
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if errRead != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||||
|
clearAntigravityPreferCredits(auth, modelName)
|
||||||
|
markAntigravityCreditsExhausted(auth, now)
|
||||||
|
}
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
// Execute performs a non-streaming request to the Antigravity API.
|
// Execute performs a non-streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
@@ -203,8 +487,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
auth = updatedAuth
|
auth = updatedAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("antigravity")
|
to := sdktranslator.FromString("antigravity")
|
||||||
@@ -222,8 +506,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -237,7 +521,15 @@ attemptLoop:
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
|
requestPayload := translated
|
||||||
|
usedCreditsDirect := false
|
||||||
|
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
|
||||||
|
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
|
||||||
|
requestPayload = creditsPayload
|
||||||
|
usedCreditsDirect = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, false, opts.Alt, baseURL)
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
err = errReq
|
err = errReq
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -245,7 +537,7 @@ attemptLoop:
|
|||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
return resp, errDo
|
return resp, errDo
|
||||||
}
|
}
|
||||||
@@ -260,20 +552,50 @@ attemptLoop:
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
err = errRead
|
err = errRead
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if usedCreditsDirect {
|
||||||
|
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||||
|
clearAntigravityPreferCredits(auth, baseModel)
|
||||||
|
markAntigravityCreditsExhausted(auth, time.Now())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes)
|
||||||
|
if creditsResp != nil {
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone())
|
||||||
|
creditsBody, errCreditsRead := io.ReadAll(creditsResp.Body)
|
||||||
|
if errClose := creditsResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close credits success response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if errCreditsRead != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead)
|
||||||
|
err = errCreditsRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody)
|
||||||
|
reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody))
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()}
|
||||||
|
reporter.EnsurePublished(ctx)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
lastErr = nil
|
lastErr = nil
|
||||||
@@ -281,6 +603,14 @@ attemptLoop:
|
|||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if antigravityShouldRetryTransientResourceExhausted429(httpResp.StatusCode, bodyBytes) && attempt+1 < attempts {
|
||||||
|
delay := antigravityTransient429RetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: transient 429 resource exhausted for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return resp, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
@@ -295,33 +625,21 @@ attemptLoop:
|
|||||||
continue attemptLoop
|
continue attemptLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case lastStatus != 0:
|
case lastStatus != 0:
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
err = newAntigravityStatusErr(lastStatus, lastBody)
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
case lastErr != nil:
|
||||||
err = lastErr
|
err = lastErr
|
||||||
default:
|
default:
|
||||||
@@ -345,8 +663,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
auth = updatedAuth
|
auth = updatedAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("antigravity")
|
to := sdktranslator.FromString("antigravity")
|
||||||
@@ -364,8 +682,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -379,7 +697,15 @@ attemptLoop:
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
requestPayload := translated
|
||||||
|
usedCreditsDirect := false
|
||||||
|
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
|
||||||
|
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
|
||||||
|
requestPayload = creditsPayload
|
||||||
|
usedCreditsDirect = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
err = errReq
|
err = errReq
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -387,7 +713,7 @@ attemptLoop:
|
|||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
return resp, errDo
|
return resp, errDo
|
||||||
}
|
}
|
||||||
@@ -401,14 +727,14 @@ attemptLoop:
|
|||||||
err = errDo
|
err = errDo
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||||
err = errRead
|
err = errRead
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -427,7 +753,24 @@ attemptLoop:
|
|||||||
err = errRead
|
err = errRead
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if usedCreditsDirect {
|
||||||
|
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||||
|
clearAntigravityPreferCredits(auth, baseModel)
|
||||||
|
markAntigravityCreditsExhausted(auth, time.Now())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
|
||||||
|
if creditsResp != nil {
|
||||||
|
httpResp = creditsResp
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
|
goto streamSuccessClaudeNonStream
|
||||||
|
}
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
lastErr = nil
|
lastErr = nil
|
||||||
@@ -435,6 +778,14 @@ attemptLoop:
|
|||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if antigravityShouldRetryTransientResourceExhausted429(httpResp.StatusCode, bodyBytes) && attempt+1 < attempts {
|
||||||
|
delay := antigravityTransient429RetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: transient 429 resource exhausted for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return resp, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
@@ -449,16 +800,11 @@ attemptLoop:
|
|||||||
continue attemptLoop
|
continue attemptLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
streamSuccessClaudeNonStream:
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
go func(resp *http.Response) {
|
go func(resp *http.Response) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@@ -471,29 +817,29 @@ attemptLoop:
|
|||||||
scanner.Buffer(nil, streamScannerBuffer)
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
// Filter usage metadata for all models
|
// Filter usage metadata for all models
|
||||||
// Only retain usage statistics in the terminal chunk
|
// Only retain usage statistics in the terminal chunk
|
||||||
line = FilterSSEUsageMetadata(line)
|
line = helps.FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
payload := jsonPayload(line)
|
payload := helps.JSONPayload(line)
|
||||||
if payload == nil {
|
if payload == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
} else {
|
} else {
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
}
|
}
|
||||||
}(httpResp)
|
}(httpResp)
|
||||||
|
|
||||||
@@ -509,24 +855,18 @@ attemptLoop:
|
|||||||
}
|
}
|
||||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||||
|
|
||||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
reporter.Publish(ctx, helps.ParseAntigravityUsage(resp.Payload))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case lastStatus != 0:
|
case lastStatus != 0:
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
err = newAntigravityStatusErr(lastStatus, lastBody)
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
case lastErr != nil:
|
||||||
err = lastErr
|
err = lastErr
|
||||||
default:
|
default:
|
||||||
@@ -748,8 +1088,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
auth = updatedAuth
|
auth = updatedAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("antigravity")
|
to := sdktranslator.FromString("antigravity")
|
||||||
@@ -767,8 +1107,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -782,14 +1122,22 @@ attemptLoop:
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
for idx, baseURL := range baseURLs {
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
requestPayload := translated
|
||||||
|
usedCreditsDirect := false
|
||||||
|
if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) {
|
||||||
|
if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 {
|
||||||
|
requestPayload = creditsPayload
|
||||||
|
usedCreditsDirect = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL)
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
err = errReq
|
err = errReq
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
return nil, errDo
|
return nil, errDo
|
||||||
}
|
}
|
||||||
@@ -803,14 +1151,14 @@ attemptLoop:
|
|||||||
err = errDo
|
err = errDo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||||
err = errRead
|
err = errRead
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -829,7 +1177,24 @@ attemptLoop:
|
|||||||
err = errRead
|
err = errRead
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if usedCreditsDirect {
|
||||||
|
if shouldMarkAntigravityCreditsExhausted(httpResp.StatusCode, bodyBytes, nil) {
|
||||||
|
clearAntigravityPreferCredits(auth, baseModel)
|
||||||
|
markAntigravityCreditsExhausted(auth, time.Now())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes)
|
||||||
|
if creditsResp != nil {
|
||||||
|
httpResp = creditsResp
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
|
goto streamSuccessExecuteStream
|
||||||
|
}
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
lastErr = nil
|
lastErr = nil
|
||||||
@@ -837,6 +1202,14 @@ attemptLoop:
|
|||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if antigravityShouldRetryTransientResourceExhausted429(httpResp.StatusCode, bodyBytes) && attempt+1 < attempts {
|
||||||
|
delay := antigravityTransient429RetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: transient 429 resource exhausted for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return nil, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
@@ -851,16 +1224,11 @@ attemptLoop:
|
|||||||
continue attemptLoop
|
continue attemptLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes)
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
streamSuccessExecuteStream:
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
go func(resp *http.Response) {
|
go func(resp *http.Response) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@@ -874,19 +1242,19 @@ attemptLoop:
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
// Filter usage metadata for all models
|
// Filter usage metadata for all models
|
||||||
// Only retain usage statistics in the terminal chunk
|
// Only retain usage statistics in the terminal chunk
|
||||||
line = FilterSSEUsageMetadata(line)
|
line = helps.FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
payload := jsonPayload(line)
|
payload := helps.JSONPayload(line)
|
||||||
if payload == nil {
|
if payload == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
if detail, ok := helps.ParseAntigravityStreamUsage(payload); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||||
@@ -899,11 +1267,11 @@ attemptLoop:
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
} else {
|
} else {
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
}
|
}
|
||||||
}(httpResp)
|
}(httpResp)
|
||||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
@@ -911,13 +1279,7 @@ attemptLoop:
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case lastStatus != 0:
|
case lastStatus != 0:
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
err = newAntigravityStatusErr(lastStatus, lastBody)
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
case lastErr != nil:
|
||||||
err = lastErr
|
err = lastErr
|
||||||
default:
|
default:
|
||||||
@@ -1011,8 +1373,13 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
if host := resolveHost(base); host != "" {
|
if host := resolveHost(base); host != "" {
|
||||||
httpReq.Host = host
|
httpReq.Host = host
|
||||||
}
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: requestURL.String(),
|
URL: requestURL.String(),
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -1026,7 +1393,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
return cliproxyexecutor.Response{}, errDo
|
return cliproxyexecutor.Response{}, errDo
|
||||||
}
|
}
|
||||||
@@ -1040,16 +1407,16 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
return cliproxyexecutor.Response{}, errDo
|
return cliproxyexecutor.Response{}, errDo
|
||||||
}
|
}
|
||||||
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
return cliproxyexecutor.Response{}, errRead
|
return cliproxyexecutor.Response{}, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
|
||||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||||
@@ -1305,6 +1672,11 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
if host := resolveHost(base); host != "" {
|
if host := resolveHost(base); host != "" {
|
||||||
httpReq.Host = host
|
httpReq.Host = host
|
||||||
}
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -1316,7 +1688,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
if e.cfg != nil && e.cfg.RequestLog {
|
if e.cfg != nil && e.cfg.RequestLog {
|
||||||
payloadLog = []byte(payloadStr)
|
payloadLog = []byte(payloadStr)
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: requestURL.String(),
|
URL: requestURL.String(),
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -1420,7 +1792,7 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return defaultAntigravityAgent
|
return misc.AntigravityUserAgent()
|
||||||
}
|
}
|
||||||
|
|
||||||
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
|
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
|
||||||
@@ -1454,6 +1826,24 @@ func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
|
|||||||
return strings.Contains(msg, "no capacity available")
|
return strings.Contains(msg, "no capacity available")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func antigravityShouldRetryTransientResourceExhausted429(statusCode int, body []byte) bool {
|
||||||
|
if statusCode != http.StatusTooManyRequests {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if classifyAntigravity429(body) != antigravity429Unknown {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String())
|
||||||
|
if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(string(body))
|
||||||
|
return strings.Contains(msg, "resource has been exhausted")
|
||||||
|
}
|
||||||
|
|
||||||
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
|
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
|
||||||
if attempt < 0 {
|
if attempt < 0 {
|
||||||
attempt = 0
|
attempt = 0
|
||||||
@@ -1465,6 +1855,17 @@ func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
|
|||||||
return delay
|
return delay
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func antigravityTransient429RetryDelay(attempt int) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
delay := time.Duration(attempt+1) * 100 * time.Millisecond
|
||||||
|
if delay > 500*time.Millisecond {
|
||||||
|
delay = 500 * time.Millisecond
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
func antigravityWait(ctx context.Context, wait time.Duration) error {
|
func antigravityWait(ctx context.Context, wait time.Duration) error {
|
||||||
if wait <= 0 {
|
if wait <= 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -1479,7 +1880,7 @@ func antigravityWait(ctx context.Context, wait time.Duration) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
var antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
|
||||||
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
||||||
return []string{base}
|
return []string{base}
|
||||||
}
|
}
|
||||||
|
|||||||
489
internal/runtime/executor/antigravity_executor_credits_test.go
Normal file
489
internal/runtime/executor/antigravity_executor_credits_test.go
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetAntigravityCreditsRetryState() {
|
||||||
|
antigravityCreditsExhaustedByAuth = sync.Map{}
|
||||||
|
antigravityPreferCreditsByModel = sync.Map{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyAntigravity429(t *testing.T) {
|
||||||
|
t.Run("quota exhausted", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||||
|
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
|
||||||
|
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("structured rate limit", func(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
if got := classifyAntigravity429(body); got != antigravity429RateLimited {
|
||||||
|
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("structured quota exhausted", func(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "QUOTA_EXHAUSTED"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
if got := classifyAntigravity429(body); got != antigravity429QuotaExhausted {
|
||||||
|
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429QuotaExhausted)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"message":"too many requests"}}`)
|
||||||
|
if got := classifyAntigravity429(body); got != antigravity429Unknown {
|
||||||
|
t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429Unknown)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectEnabledCreditTypes(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gemini-2.5-flash","request":{}}`)
|
||||||
|
got := injectEnabledCreditTypes(body)
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("injectEnabledCreditTypes() returned nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(got), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("injectEnabledCreditTypes() = %s, want enabledCreditTypes", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := injectEnabledCreditTypes([]byte(`not json`)); got != nil {
|
||||||
|
t.Fatalf("injectEnabledCreditTypes() for invalid json = %s, want nil", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) {
|
||||||
|
t.Run("credit errors are marked", func(t *testing.T) {
|
||||||
|
for _, body := range [][]byte{
|
||||||
|
[]byte(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`),
|
||||||
|
[]byte(`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`),
|
||||||
|
} {
|
||||||
|
if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, body, nil) {
|
||||||
|
t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("transient 429 resource exhausted is not marked", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`)
|
||||||
|
if shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, body, nil) {
|
||||||
|
t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = true, want false", string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("resource exhausted with quota metadata is still marked", func(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted","status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"1h","model":"claude-sonnet-4-6"}}]}}`)
|
||||||
|
if !shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, body, nil) {
|
||||||
|
t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) {
|
||||||
|
t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityExecute_RetriesTransient429ResourceExhausted(t *testing.T) {
|
||||||
|
resetAntigravityCreditsRetryState()
|
||||||
|
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||||
|
|
||||||
|
var requestCount int
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestCount++
|
||||||
|
switch requestCount {
|
||||||
|
case 1:
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`))
|
||||||
|
case 2:
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request count %d", requestCount)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewAntigravityExecutor(&config.Config{RequestRetry: 1})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-transient-429",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"project_id": "project-1",
|
||||||
|
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gemini-2.5-flash",
|
||||||
|
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatAntigravity,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Payload) == 0 {
|
||||||
|
t.Fatal("Execute() returned empty payload")
|
||||||
|
}
|
||||||
|
if requestCount != 2 {
|
||||||
|
t.Fatalf("request count = %d, want 2", requestCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) {
|
||||||
|
resetAntigravityCreditsRetryState()
|
||||||
|
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
requestBodies []string
|
||||||
|
)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
_ = r.Body.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
requestBodies = append(requestBodies, string(body))
|
||||||
|
reqNum := len(requestBodies)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
if reqNum == 1 {
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("second request body missing enabledCreditTypes: %s", string(body))
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewAntigravityExecutor(&config.Config{
|
||||||
|
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-credits-ok",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"project_id": "project-1",
|
||||||
|
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gemini-2.5-flash",
|
||||||
|
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatAntigravity,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Payload) == 0 {
|
||||||
|
t.Fatal("Execute() returned empty payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(requestBodies) != 2 {
|
||||||
|
t.Fatalf("request count = %d, want 2", len(requestBodies))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) {
|
||||||
|
resetAntigravityCreditsRetryState()
|
||||||
|
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||||
|
|
||||||
|
var requestCount int
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestCount++
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewAntigravityExecutor(&config.Config{
|
||||||
|
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-credits-exhausted",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"project_id": "project-1",
|
||||||
|
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
markAntigravityCreditsExhausted(auth, time.Now())
|
||||||
|
|
||||||
|
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gemini-2.5-flash",
|
||||||
|
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatAntigravity,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Execute() error = nil, want 429")
|
||||||
|
}
|
||||||
|
sErr, ok := err.(statusErr)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Execute() error type = %T, want statusErr", err)
|
||||||
|
}
|
||||||
|
if got := sErr.StatusCode(); got != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("Execute() status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if requestCount != 1 {
|
||||||
|
t.Fatalf("request count = %d, want 1", requestCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) {
|
||||||
|
resetAntigravityCreditsRetryState()
|
||||||
|
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
requestBodies []string
|
||||||
|
)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
_ = r.Body.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
requestBodies = append(requestBodies, string(body))
|
||||||
|
reqNum := len(requestBodies)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
switch reqNum {
|
||||||
|
case 1:
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`))
|
||||||
|
case 2, 3:
|
||||||
|
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("request %d body missing enabledCreditTypes: %s", reqNum, string(body))
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request count %d", reqNum)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewAntigravityExecutor(&config.Config{
|
||||||
|
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-prefer-credits",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"project_id": "project-1",
|
||||||
|
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request := cliproxyexecutor.Request{
|
||||||
|
Model: "gemini-2.5-flash",
|
||||||
|
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||||
|
}
|
||||||
|
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity}
|
||||||
|
|
||||||
|
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
|
||||||
|
t.Fatalf("first Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil {
|
||||||
|
t.Fatalf("second Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(requestBodies) != 3 {
|
||||||
|
t.Fatalf("request count = %d, want 3", len(requestBodies))
|
||||||
|
}
|
||||||
|
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0])
|
||||||
|
}
|
||||||
|
if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("fallback request missing credits: %s", requestBodies[1])
|
||||||
|
}
|
||||||
|
if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("preferred request missing credits: %s", requestBodies[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) {
|
||||||
|
resetAntigravityCreditsRetryState()
|
||||||
|
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
firstCount int
|
||||||
|
secondCount int
|
||||||
|
)
|
||||||
|
|
||||||
|
firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
_ = r.Body.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
firstCount++
|
||||||
|
reqNum := firstCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
switch reqNum {
|
||||||
|
case 1:
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`))
|
||||||
|
case 2:
|
||||||
|
if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("credits retry missing enabledCreditTypes: %s", string(body))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected first server request count %d", reqNum)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer firstServer.Close()
|
||||||
|
|
||||||
|
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mu.Lock()
|
||||||
|
secondCount++
|
||||||
|
mu.Unlock()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`))
|
||||||
|
}))
|
||||||
|
defer secondServer.Close()
|
||||||
|
|
||||||
|
exec := NewAntigravityExecutor(&config.Config{
|
||||||
|
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-baseurl-fallback",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": firstServer.URL,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"project_id": "project-1",
|
||||||
|
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
originalOrder := antigravityBaseURLFallbackOrder
|
||||||
|
defer func() { antigravityBaseURLFallbackOrder = originalOrder }()
|
||||||
|
antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string {
|
||||||
|
return []string{firstServer.URL, secondServer.URL}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gemini-2.5-flash",
|
||||||
|
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatAntigravity,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Payload) == 0 {
|
||||||
|
t.Fatal("Execute() returned empty payload")
|
||||||
|
}
|
||||||
|
if firstCount != 2 {
|
||||||
|
t.Fatalf("first server request count = %d, want 2", firstCount)
|
||||||
|
}
|
||||||
|
if secondCount != 1 {
|
||||||
|
t.Fatalf("second server request count = %d, want 1", secondCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) {
|
||||||
|
resetAntigravityCreditsRetryState()
|
||||||
|
t.Cleanup(resetAntigravityCreditsRetryState)
|
||||||
|
|
||||||
|
var requestBodies []string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
_ = r.Body.Close()
|
||||||
|
requestBodies = append(requestBodies, string(body))
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewAntigravityExecutor(&config.Config{
|
||||||
|
QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "auth-flag-disabled",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"project_id": "project-1",
|
||||||
|
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil)
|
||||||
|
|
||||||
|
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gemini-2.5-flash",
|
||||||
|
Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatAntigravity,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Execute() error = nil, want 429")
|
||||||
|
}
|
||||||
|
if len(requestBodies) != 1 {
|
||||||
|
t.Fatalf("request count = %d, want 1", len(requestBodies))
|
||||||
|
}
|
||||||
|
if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) {
|
||||||
|
t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -4,9 +4,11 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -14,7 +16,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
|
xxHash64 "github.com/pierrec/xxHash/xxHash64"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -23,9 +28,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func resetClaudeDeviceProfileCache() {
|
func resetClaudeDeviceProfileCache() {
|
||||||
claudeDeviceProfileCacheMu.Lock()
|
helps.ResetClaudeDeviceProfileCache()
|
||||||
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
|
||||||
claudeDeviceProfileCacheMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
|
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
|
||||||
@@ -98,7 +101,7 @@ func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
|
|||||||
req := newClaudeHeaderTestRequest(t, incoming)
|
req := newClaudeHeaderTestRequest(t, incoming)
|
||||||
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
|
||||||
|
|
||||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
|
assertClaudeFingerprint(t, req.Header, "evil-client/9.9", "9.9.9", "v24.5.0", "Linux", "x64")
|
||||||
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
|
||||||
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
|
||||||
}
|
}
|
||||||
@@ -338,7 +341,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
|||||||
var pauseOnce sync.Once
|
var pauseOnce sync.Once
|
||||||
var releaseOnce sync.Once
|
var releaseOnce sync.Once
|
||||||
|
|
||||||
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
|
helps.ClaudeDeviceProfileBeforeCandidateStore = func(candidate helps.ClaudeDeviceProfile) {
|
||||||
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
|
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -346,13 +349,13 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
|||||||
<-releaseLow
|
<-releaseLow
|
||||||
}
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
claudeDeviceProfileBeforeCandidateStore = nil
|
helps.ClaudeDeviceProfileBeforeCandidateStore = nil
|
||||||
releaseOnce.Do(func() { close(releaseLow) })
|
releaseOnce.Do(func() { close(releaseLow) })
|
||||||
})
|
})
|
||||||
|
|
||||||
lowResultCh := make(chan claudeDeviceProfile, 1)
|
lowResultCh := make(chan helps.ClaudeDeviceProfile, 1)
|
||||||
go func() {
|
go func() {
|
||||||
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
lowResultCh <- helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||||
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
|
||||||
"X-Stainless-Package-Version": []string{"0.74.0"},
|
"X-Stainless-Package-Version": []string{"0.74.0"},
|
||||||
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
|
||||||
@@ -367,7 +370,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
|||||||
t.Fatal("timed out waiting for lower candidate to pause before storing")
|
t.Fatal("timed out waiting for lower candidate to pause before storing")
|
||||||
}
|
}
|
||||||
|
|
||||||
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
highResult := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||||
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
|
||||||
"X-Stainless-Package-Version": []string{"0.75.0"},
|
"X-Stainless-Package-Version": []string{"0.75.0"},
|
||||||
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
|
||||||
@@ -398,7 +401,7 @@ func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testi
|
|||||||
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
|
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
|
||||||
}
|
}
|
||||||
|
|
||||||
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
cached := helps.ResolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
|
||||||
"User-Agent": []string{"curl/8.7.1"},
|
"User-Agent": []string{"curl/8.7.1"},
|
||||||
}, cfg)
|
}, cfg)
|
||||||
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
|
||||||
@@ -564,7 +567,7 @@ func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *tes
|
|||||||
})
|
})
|
||||||
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
|
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
|
||||||
|
|
||||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
|
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
|
||||||
@@ -591,14 +594,14 @@ func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallbac
|
|||||||
})
|
})
|
||||||
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
|
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
|
||||||
|
|
||||||
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
|
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", helps.MapStainlessOS(), helps.MapStainlessArch())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
|
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
|
||||||
if claudeDeviceProfileStabilizationEnabled(nil) {
|
if helps.ClaudeDeviceProfileStabilizationEnabled(nil) {
|
||||||
t.Fatal("expected nil config to default to disabled stabilization")
|
t.Fatal("expected nil config to default to disabled stabilization")
|
||||||
}
|
}
|
||||||
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
if helps.ClaudeDeviceProfileStabilizationEnabled(&config.Config{}) {
|
||||||
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
|
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -736,6 +739,35 @@ func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) {
|
||||||
|
for _, builtin := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||||
|
t.Run(builtin, func(t *testing.T) {
|
||||||
|
input := []byte(fmt.Sprintf(`{
|
||||||
|
"tools":[{"name":"Read"}],
|
||||||
|
"tool_choice":{"type":"tool","name":%q},
|
||||||
|
"messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}]
|
||||||
|
}`, builtin, builtin, builtin, builtin))
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin {
|
||||||
|
t.Fatalf("tool_choice.name = %q, want %q", got, builtin)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin {
|
||||||
|
t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin {
|
||||||
|
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin {
|
||||||
|
t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
@@ -796,8 +828,6 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||||
resetUserIDCache()
|
|
||||||
|
|
||||||
var userIDs []string
|
var userIDs []string
|
||||||
var requestModels []string
|
var requestModels []string
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -857,15 +887,13 @@ func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
|||||||
if userIDs[0] != userIDs[1] {
|
if userIDs[0] != userIDs[1] {
|
||||||
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||||
}
|
}
|
||||||
if !isValidUserID(userIDs[0]) {
|
if !helps.IsValidUserID(userIDs[0]) {
|
||||||
t.Fatalf("user_id %q is not valid", userIDs[0])
|
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||||
}
|
}
|
||||||
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||||
resetUserIDCache()
|
|
||||||
|
|
||||||
var userIDs []string
|
var userIDs []string
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, _ := io.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
@@ -903,7 +931,7 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
|||||||
if userIDs[0] == userIDs[1] {
|
if userIDs[0] == userIDs[1] {
|
||||||
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||||
}
|
}
|
||||||
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
if !helps.IsValidUserID(userIDs[0]) || !helps.IsValidUserID(userIDs[1]) {
|
||||||
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -966,6 +994,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) {
|
||||||
|
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
|
||||||
|
|
||||||
|
out := normalizeCacheControlTTL(payload)
|
||||||
|
|
||||||
|
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
|
||||||
|
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
|
||||||
|
}
|
||||||
|
|
||||||
|
outStr := string(out)
|
||||||
|
idxModel := strings.Index(outStr, `"model"`)
|
||||||
|
idxMessages := strings.Index(outStr, `"messages"`)
|
||||||
|
idxTools := strings.Index(outStr, `"tools"`)
|
||||||
|
idxSystem := strings.Index(outStr, `"system"`)
|
||||||
|
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
|
||||||
|
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
|
||||||
|
}
|
||||||
|
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
|
||||||
|
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||||
payload := []byte(`{
|
payload := []byte(`{
|
||||||
"tools": [
|
"tools": [
|
||||||
@@ -995,6 +1045,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) {
|
||||||
|
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
|
||||||
|
|
||||||
|
out := enforceCacheControlLimit(payload, 4)
|
||||||
|
|
||||||
|
if got := countCacheControls(out); got != 4 {
|
||||||
|
t.Fatalf("cache_control count = %d, want 4", got)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||||
|
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
|
||||||
|
}
|
||||||
|
|
||||||
|
outStr := string(out)
|
||||||
|
idxModel := strings.Index(outStr, `"model"`)
|
||||||
|
idxMessages := strings.Index(outStr, `"messages"`)
|
||||||
|
idxTools := strings.Index(outStr, `"tools"`)
|
||||||
|
idxSystem := strings.Index(outStr, `"system"`)
|
||||||
|
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
|
||||||
|
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
|
||||||
|
}
|
||||||
|
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
|
||||||
|
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
||||||
payload := []byte(`{
|
payload := []byte(`{
|
||||||
"tools": [
|
"tools": [
|
||||||
@@ -1183,6 +1258,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-claude-max-completion-tokens-client"
|
||||||
|
modelID := "test-claude-max-completion-tokens-model"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "claude",
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
MaxCompletionTokens: 4096,
|
||||||
|
UserDefined: true,
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, modelID)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
|
||||||
|
t.Fatalf("max_tokens = %d, want %d", got, 4096)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-claude-default-max-tokens-client"
|
||||||
|
modelID := "test-claude-default-max-tokens-model"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "claude",
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
UserDefined: true,
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, modelID)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
|
||||||
|
t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-claude-preserve-max-tokens-client"
|
||||||
|
modelID := "test-claude-preserve-max-tokens-model"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
|
||||||
|
ID: modelID,
|
||||||
|
Type: "claude",
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
MaxCompletionTokens: 4096,
|
||||||
|
UserDefined: true,
|
||||||
|
}})
|
||||||
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
|
input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, modelID)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
|
||||||
|
t.Fatalf("max_tokens = %d, want %d", got, 2048)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
|
||||||
|
input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := ensureModelMaxTokens(input, "test-claude-unregistered-model")
|
||||||
|
|
||||||
|
if gjson.GetBytes(out, "max_tokens").Exists() {
|
||||||
|
t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
||||||
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
||||||
// compressed SSE body that would silently break the line scanner.
|
// compressed SSE body that would silently break the line scanner.
|
||||||
@@ -1340,6 +1492,35 @@ func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
||||||
|
// detects zstd-compressed content via magic bytes even when Content-Encoding is absent.
|
||||||
|
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
||||||
|
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc, err := zstd.NewWriter(&buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("zstd.NewWriter: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = enc.Write([]byte(plaintext))
|
||||||
|
_ = enc.Close()
|
||||||
|
|
||||||
|
rc := io.NopCloser(&buf)
|
||||||
|
decoded, err := decodeResponseBody(rc, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeResponseBody error: %v", err)
|
||||||
|
}
|
||||||
|
defer decoded.Close()
|
||||||
|
|
||||||
|
got, err := io.ReadAll(decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != plaintext {
|
||||||
|
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
|
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
|
||||||
// plain text untouched when Content-Encoding is absent and no magic bytes match.
|
// plain text untouched when Content-Encoding is absent and no magic bytes match.
|
||||||
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
|
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
|
||||||
@@ -1411,77 +1592,6 @@ func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
|
|
||||||
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
|
|
||||||
// path's enforced identity encoding.
|
|
||||||
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
|
||||||
var gotEncoding string
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
executor := NewClaudeExecutor(&config.Config{})
|
|
||||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
|
||||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
|
||||||
"api_key": "key-123",
|
|
||||||
"base_url": server.URL,
|
|
||||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
|
||||||
}}
|
|
||||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
|
||||||
|
|
||||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
|
||||||
Model: "claude-3-5-sonnet-20241022",
|
|
||||||
Payload: payload,
|
|
||||||
}, cliproxyexecutor.Options{
|
|
||||||
SourceFormat: sdktranslator.FromString("claude"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExecuteStream error: %v", err)
|
|
||||||
}
|
|
||||||
for chunk := range result.Chunks {
|
|
||||||
if chunk.Err != nil {
|
|
||||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotEncoding != "identity" {
|
|
||||||
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
|
||||||
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
|
|
||||||
// Content-Encoding is absent.
|
|
||||||
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
|
||||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
enc, err := zstd.NewWriter(&buf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("zstd.NewWriter: %v", err)
|
|
||||||
}
|
|
||||||
_, _ = enc.Write([]byte(plaintext))
|
|
||||||
_ = enc.Close()
|
|
||||||
|
|
||||||
rc := io.NopCloser(&buf)
|
|
||||||
decoded, err := decodeResponseBody(rc, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decodeResponseBody error: %v", err)
|
|
||||||
}
|
|
||||||
defer decoded.Close()
|
|
||||||
|
|
||||||
got, err := io.ReadAll(decoded)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadAll error: %v", err)
|
|
||||||
}
|
|
||||||
if string(got) != plaintext {
|
|
||||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
|
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
|
||||||
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
|
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
|
||||||
// the Content-Encoding header. This closes the gap left by PR #1771, which only
|
// the Content-Encoding header. This closes the gap left by PR #1771, which only
|
||||||
@@ -1565,6 +1675,45 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies that the
|
||||||
|
// streaming executor enforces Accept-Encoding: identity regardless of auth.Attributes override.
|
||||||
|
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
||||||
|
var gotEncoding string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream error: %v", err)
|
||||||
|
}
|
||||||
|
for chunk := range result.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotEncoding != "identity" {
|
||||||
|
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Test case 1: String system prompt is preserved and converted to a content block
|
// Test case 1: String system prompt is preserved and converted to a content block
|
||||||
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
@@ -1648,3 +1797,155 @@ func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
|||||||
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_ExperimentalCCHSigningDisabledByDefaultKeepsLegacyHeader(t *testing.T) {
|
||||||
|
var seenBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
seenBody = bytes.Clone(body)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(seenBody) == 0 {
|
||||||
|
t.Fatal("expected request body to be captured")
|
||||||
|
}
|
||||||
|
|
||||||
|
billingHeader := gjson.GetBytes(seenBody, "system.0.text").String()
|
||||||
|
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
|
||||||
|
t.Fatalf("system.0.text = %q, want billing header", billingHeader)
|
||||||
|
}
|
||||||
|
if strings.Contains(billingHeader, "cch=00000;") {
|
||||||
|
t.Fatalf("legacy mode should not forward cch placeholder, got %q", billingHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_ExperimentalCCHSigningOptInSignsFinalBody(t *testing.T) {
|
||||||
|
var seenBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
seenBody = bytes.Clone(body)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{{
|
||||||
|
APIKey: "key-123",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ExperimentalCCHSigning: true,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
const messageText = "please keep literal cch=00000 in this message"
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"please keep literal cch=00000 in this message"}]}]}`)
|
||||||
|
|
||||||
|
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(seenBody) == 0 {
|
||||||
|
t.Fatal("expected request body to be captured")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(seenBody, "messages.0.content.0.text").String(); got != messageText {
|
||||||
|
t.Fatalf("message text = %q, want %q", got, messageText)
|
||||||
|
}
|
||||||
|
|
||||||
|
billingPattern := regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)([0-9a-f]{5})(;)`)
|
||||||
|
match := billingPattern.FindSubmatch(seenBody)
|
||||||
|
if match == nil {
|
||||||
|
t.Fatalf("expected signed billing header in body: %s", string(seenBody))
|
||||||
|
}
|
||||||
|
actualCCH := string(match[2])
|
||||||
|
unsignedBody := billingPattern.ReplaceAll(seenBody, []byte(`${1}00000${3}`))
|
||||||
|
wantCCH := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, 0x6E52736AC806831E)&0xFFFFF)
|
||||||
|
if actualCCH != wantCCH {
|
||||||
|
t.Fatalf("cch = %q, want %q\nbody: %s", actualCCH, wantCCH, string(seenBody))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmitted(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{{
|
||||||
|
APIKey: "key-123",
|
||||||
|
Cloak: &config.CloakConfig{
|
||||||
|
StrictMode: true,
|
||||||
|
SensitiveWords: []string{"proxy"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "key-123"}}
|
||||||
|
payload := []byte(`{"system":"proxy rules","messages":[{"role":"user","content":[{"type":"text","text":"proxy access"}]}]}`)
|
||||||
|
|
||||||
|
out := applyCloaking(context.Background(), cfg, auth, payload, "claude-3-5-sonnet-20241022", "key-123")
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 2 {
|
||||||
|
t.Fatalf("expected strict mode to keep only injected system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); !strings.Contains(got, "\u200B") {
|
||||||
|
t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeTemperatureForThinking_AdaptiveCoercesToOne(t *testing.T) {
|
||||||
|
payload := []byte(`{"temperature":0,"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`)
|
||||||
|
out := normalizeClaudeTemperatureForThinking(payload)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "temperature").Float(); got != 1 {
|
||||||
|
t.Fatalf("temperature = %v, want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeTemperatureForThinking_EnabledCoercesToOne(t *testing.T) {
|
||||||
|
payload := []byte(`{"temperature":0.2,"thinking":{"type":"enabled","budget_tokens":2048}}`)
|
||||||
|
out := normalizeClaudeTemperatureForThinking(payload)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "temperature").Float(); got != 1 {
|
||||||
|
t.Fatalf("temperature = %v, want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeTemperatureForThinking_NoThinkingLeavesTemperatureAlone(t *testing.T) {
|
||||||
|
payload := []byte(`{"temperature":0,"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
out := normalizeClaudeTemperatureForThinking(payload)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "temperature").Float(); got != 0 {
|
||||||
|
t.Fatalf("temperature = %v, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOriginalTemperature(t *testing.T) {
|
||||||
|
payload := []byte(`{"temperature":0,"thinking":{"type":"adaptive"},"output_config":{"effort":"max"},"tool_choice":{"type":"any"}}`)
|
||||||
|
out := disableThinkingIfToolChoiceForced(payload)
|
||||||
|
out = normalizeClaudeTemperatureForThinking(out)
|
||||||
|
|
||||||
|
if gjson.GetBytes(out, "thinking").Exists() {
|
||||||
|
t.Fatalf("thinking should be removed when tool_choice forces tool use")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "temperature").Float(); got != 0 {
|
||||||
|
t.Fatalf("temperature = %v, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
81
internal/runtime/executor/claude_signing.go
Normal file
81
internal/runtime/executor/claude_signing.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
xxHash64 "github.com/pierrec/xxHash/xxHash64"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const claudeCCHSeed uint64 = 0x6E52736AC806831E
|
||||||
|
|
||||||
|
var claudeBillingHeaderCCHPattern = regexp.MustCompile(`\bcch=([0-9a-f]{5});`)
|
||||||
|
|
||||||
|
func signAnthropicMessagesBody(body []byte) []byte {
|
||||||
|
billingHeader := gjson.GetBytes(body, "system.0.text").String()
|
||||||
|
if !strings.HasPrefix(billingHeader, "x-anthropic-billing-header:") {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if !claudeBillingHeaderCCHPattern.MatchString(billingHeader) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
unsignedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(billingHeader, "cch=00000;")
|
||||||
|
unsignedBody, err := sjson.SetBytes(body, "system.0.text", unsignedBillingHeader)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
cch := fmt.Sprintf("%05x", xxHash64.Checksum(unsignedBody, claudeCCHSeed)&0xFFFFF)
|
||||||
|
signedBillingHeader := claudeBillingHeaderCCHPattern.ReplaceAllString(unsignedBillingHeader, "cch="+cch+";")
|
||||||
|
signedBody, err := sjson.SetBytes(unsignedBody, "system.0.text", signedBillingHeader)
|
||||||
|
if err != nil {
|
||||||
|
return unsignedBody
|
||||||
|
}
|
||||||
|
return signedBody
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveClaudeKeyConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.ClaudeKey {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, baseURL := claudeCreds(auth)
|
||||||
|
if apiKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cfg.ClaudeKey {
|
||||||
|
entry := &cfg.ClaudeKey[i]
|
||||||
|
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||||
|
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||||
|
if !strings.EqualFold(cfgKey, apiKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||||
|
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
|
||||||
|
entry := resolveClaudeKeyConfig(cfg, auth)
|
||||||
|
if entry == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return entry.Cloak
|
||||||
|
}
|
||||||
|
|
||||||
|
func experimentalCCHSigningEnabled(cfg *config.Config, auth *cliproxyauth.Auth) bool {
|
||||||
|
entry := resolveClaudeKeyConfig(cfg, auth)
|
||||||
|
return entry != nil && entry.ExperimentalCCHSigning
|
||||||
|
}
|
||||||
@@ -4,9 +4,11 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||||
@@ -14,8 +16,11 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -98,10 +103,12 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
if len(opts.OriginalRequest) > 0 {
|
if len(opts.OriginalRequest) > 0 {
|
||||||
originalPayloadSource = opts.OriginalRequest
|
originalPayloadSource = opts.OriginalRequest
|
||||||
}
|
}
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
|
translated, _ = sjson.SetBytes(translated, "stream", true)
|
||||||
|
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -114,6 +121,8 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||||
|
httpReq.Header.Set("Accept", "text/event-stream")
|
||||||
|
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -160,11 +169,16 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
aggregatedBody, usageDetail, err := aggregateOpenAIChatCompletionStream(body)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
reporter.publish(ctx, usageDetail)
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, aggregatedBody, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -341,3 +355,197 @@ func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID,
|
|||||||
req.Header.Set("X-IDE-Version", "2.63.2")
|
req.Header.Set("X-IDE-Version", "2.63.2")
|
||||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIChatStreamChoiceAccumulator struct {
|
||||||
|
Role string
|
||||||
|
ContentParts []string
|
||||||
|
ReasoningParts []string
|
||||||
|
FinishReason string
|
||||||
|
ToolCalls map[int]*openAIChatStreamToolCallAccumulator
|
||||||
|
ToolCallOrder []int
|
||||||
|
NativeFinishReason any
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIChatStreamToolCallAccumulator struct {
|
||||||
|
ID string
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
Arguments strings.Builder
|
||||||
|
}
|
||||||
|
|
||||||
|
func aggregateOpenAIChatCompletionStream(raw []byte) ([]byte, usage.Detail, error) {
|
||||||
|
lines := bytes.Split(raw, []byte("\n"))
|
||||||
|
var (
|
||||||
|
responseID string
|
||||||
|
model string
|
||||||
|
created int64
|
||||||
|
serviceTier string
|
||||||
|
systemFP string
|
||||||
|
usageDetail usage.Detail
|
||||||
|
choices = map[int]*openAIChatStreamChoiceAccumulator{}
|
||||||
|
choiceOrder []int
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := bytes.TrimSpace(line[5:])
|
||||||
|
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(payload) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
root := gjson.ParseBytes(payload)
|
||||||
|
if responseID == "" {
|
||||||
|
responseID = root.Get("id").String()
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
model = root.Get("model").String()
|
||||||
|
}
|
||||||
|
if created == 0 {
|
||||||
|
created = root.Get("created").Int()
|
||||||
|
}
|
||||||
|
if serviceTier == "" {
|
||||||
|
serviceTier = root.Get("service_tier").String()
|
||||||
|
}
|
||||||
|
if systemFP == "" {
|
||||||
|
systemFP = root.Get("system_fingerprint").String()
|
||||||
|
}
|
||||||
|
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||||
|
usageDetail = detail
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, choiceResult := range root.Get("choices").Array() {
|
||||||
|
idx := int(choiceResult.Get("index").Int())
|
||||||
|
choice := choices[idx]
|
||||||
|
if choice == nil {
|
||||||
|
choice = &openAIChatStreamChoiceAccumulator{ToolCalls: map[int]*openAIChatStreamToolCallAccumulator{}}
|
||||||
|
choices[idx] = choice
|
||||||
|
choiceOrder = append(choiceOrder, idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := choiceResult.Get("delta")
|
||||||
|
if role := delta.Get("role").String(); role != "" {
|
||||||
|
choice.Role = role
|
||||||
|
}
|
||||||
|
if content := delta.Get("content").String(); content != "" {
|
||||||
|
choice.ContentParts = append(choice.ContentParts, content)
|
||||||
|
}
|
||||||
|
if reasoning := delta.Get("reasoning_content").String(); reasoning != "" {
|
||||||
|
choice.ReasoningParts = append(choice.ReasoningParts, reasoning)
|
||||||
|
}
|
||||||
|
if finishReason := choiceResult.Get("finish_reason").String(); finishReason != "" {
|
||||||
|
choice.FinishReason = finishReason
|
||||||
|
}
|
||||||
|
if nativeFinishReason := choiceResult.Get("native_finish_reason"); nativeFinishReason.Exists() {
|
||||||
|
choice.NativeFinishReason = nativeFinishReason.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, toolCallResult := range delta.Get("tool_calls").Array() {
|
||||||
|
toolIdx := int(toolCallResult.Get("index").Int())
|
||||||
|
toolCall := choice.ToolCalls[toolIdx]
|
||||||
|
if toolCall == nil {
|
||||||
|
toolCall = &openAIChatStreamToolCallAccumulator{}
|
||||||
|
choice.ToolCalls[toolIdx] = toolCall
|
||||||
|
choice.ToolCallOrder = append(choice.ToolCallOrder, toolIdx)
|
||||||
|
}
|
||||||
|
if id := toolCallResult.Get("id").String(); id != "" {
|
||||||
|
toolCall.ID = id
|
||||||
|
}
|
||||||
|
if typ := toolCallResult.Get("type").String(); typ != "" {
|
||||||
|
toolCall.Type = typ
|
||||||
|
}
|
||||||
|
if name := toolCallResult.Get("function.name").String(); name != "" {
|
||||||
|
toolCall.Name = name
|
||||||
|
}
|
||||||
|
if args := toolCallResult.Get("function.arguments").String(); args != "" {
|
||||||
|
toolCall.Arguments.WriteString(args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseID == "" && model == "" && len(choiceOrder) == 0 {
|
||||||
|
return nil, usageDetail, fmt.Errorf("codebuddy: streaming response did not contain any chat completion chunks")
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]any{
|
||||||
|
"id": responseID,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": created,
|
||||||
|
"model": model,
|
||||||
|
"choices": make([]map[string]any, 0, len(choiceOrder)),
|
||||||
|
"usage": map[string]any{
|
||||||
|
"prompt_tokens": usageDetail.InputTokens,
|
||||||
|
"completion_tokens": usageDetail.OutputTokens,
|
||||||
|
"total_tokens": usageDetail.TotalTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if serviceTier != "" {
|
||||||
|
response["service_tier"] = serviceTier
|
||||||
|
}
|
||||||
|
if systemFP != "" {
|
||||||
|
response["system_fingerprint"] = systemFP
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, idx := range choiceOrder {
|
||||||
|
choice := choices[idx]
|
||||||
|
message := map[string]any{
|
||||||
|
"role": choice.Role,
|
||||||
|
"content": strings.Join(choice.ContentParts, ""),
|
||||||
|
}
|
||||||
|
if message["role"] == "" {
|
||||||
|
message["role"] = "assistant"
|
||||||
|
}
|
||||||
|
if len(choice.ReasoningParts) > 0 {
|
||||||
|
message["reasoning_content"] = strings.Join(choice.ReasoningParts, "")
|
||||||
|
}
|
||||||
|
if len(choice.ToolCallOrder) > 0 {
|
||||||
|
toolCalls := make([]map[string]any, 0, len(choice.ToolCallOrder))
|
||||||
|
for _, toolIdx := range choice.ToolCallOrder {
|
||||||
|
toolCall := choice.ToolCalls[toolIdx]
|
||||||
|
toolCallType := toolCall.Type
|
||||||
|
if toolCallType == "" {
|
||||||
|
toolCallType = "function"
|
||||||
|
}
|
||||||
|
arguments := toolCall.Arguments.String()
|
||||||
|
if arguments == "" {
|
||||||
|
arguments = "{}"
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, map[string]any{
|
||||||
|
"id": toolCall.ID,
|
||||||
|
"type": toolCallType,
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": toolCall.Name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
message["tool_calls"] = toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := choice.FinishReason
|
||||||
|
if finishReason == "" {
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
|
choicePayload := map[string]any{
|
||||||
|
"index": idx,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": finishReason,
|
||||||
|
}
|
||||||
|
if choice.NativeFinishReason != nil {
|
||||||
|
choicePayload["native_finish_reason"] = choice.NativeFinishReason
|
||||||
|
}
|
||||||
|
response["choices"] = append(response["choices"].([]map[string]any), choicePayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return nil, usageDetail, fmt.Errorf("codebuddy: failed to encode aggregated response: %w", err)
|
||||||
|
}
|
||||||
|
return out, usageDetail, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,12 +7,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -28,8 +30,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
|
||||||
codexOriginator = "codex_cli_rs"
|
codexOriginator = "codex-tui"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dataTag = []byte("data:")
|
var dataTag = []byte("data:")
|
||||||
@@ -73,7 +75,7 @@ func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,8 +90,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
@@ -106,16 +108,15 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
body = normalizeCodexInstructions(body)
|
||||||
}
|
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||||
@@ -129,7 +130,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -140,10 +141,10 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
AuthType: authType,
|
AuthType: authType,
|
||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -151,38 +152,79 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
log.Errorf("codex executor: close response body error: %v", errClose)
|
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
|
||||||
lines := bytes.Split(data, []byte("\n"))
|
lines := bytes.Split(data, []byte("\n"))
|
||||||
|
outputItemsByIndex := make(map[int64][]byte)
|
||||||
|
var outputItemsFallback [][]byte
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if !bytes.HasPrefix(line, dataTag) {
|
if !bytes.HasPrefix(line, dataTag) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
line = bytes.TrimSpace(line[5:])
|
eventData := bytes.TrimSpace(line[5:])
|
||||||
if gjson.GetBytes(line, "type").String() != "response.completed" {
|
eventType := gjson.GetBytes(eventData, "type").String()
|
||||||
|
|
||||||
|
if eventType == "response.output_item.done" {
|
||||||
|
itemResult := gjson.GetBytes(eventData, "item")
|
||||||
|
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
outputIndexResult := gjson.GetBytes(eventData, "output_index")
|
||||||
|
if outputIndexResult.Exists() {
|
||||||
|
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
|
||||||
|
} else {
|
||||||
|
outputItemsFallback = append(outputItemsFallback, []byte(itemResult.Raw))
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if detail, ok := parseCodexUsage(line); ok {
|
if eventType != "response.completed" {
|
||||||
reporter.publish(ctx, detail)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||||
|
reporter.Publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
completedData := eventData
|
||||||
|
outputResult := gjson.GetBytes(completedData, "response.output")
|
||||||
|
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
|
||||||
|
if shouldPatchOutput {
|
||||||
|
completedDataPatched := completedData
|
||||||
|
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output", []byte(`[]`))
|
||||||
|
|
||||||
|
indexes := make([]int64, 0, len(outputItemsByIndex))
|
||||||
|
for idx := range outputItemsByIndex {
|
||||||
|
indexes = append(indexes, idx)
|
||||||
|
}
|
||||||
|
sort.Slice(indexes, func(i, j int) bool {
|
||||||
|
return indexes[i] < indexes[j]
|
||||||
|
})
|
||||||
|
for _, idx := range indexes {
|
||||||
|
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", outputItemsByIndex[idx])
|
||||||
|
}
|
||||||
|
for _, item := range outputItemsFallback {
|
||||||
|
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", item)
|
||||||
|
}
|
||||||
|
completedData = completedDataPatched
|
||||||
}
|
}
|
||||||
|
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, completedData, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -198,8 +240,8 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai-response")
|
to := sdktranslator.FromString("openai-response")
|
||||||
@@ -216,10 +258,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.DeleteBytes(body, "stream")
|
body, _ = sjson.DeleteBytes(body, "stream")
|
||||||
|
body = normalizeCodexInstructions(body)
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||||
@@ -233,7 +276,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -244,10 +287,10 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
AuthType: authType,
|
AuthType: authType,
|
||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -255,22 +298,22 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
log.Errorf("codex executor: close response body error: %v", errClose)
|
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = newCodexStatusErr(httpResp.StatusCode, b)
|
err = newCodexStatusErr(httpResp.StatusCode, b)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
@@ -288,8 +331,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
@@ -306,15 +349,14 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
body = normalizeCodexInstructions(body)
|
||||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||||
@@ -328,7 +370,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -340,24 +382,24 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
data, readErr := io.ReadAll(httpResp.Body)
|
data, readErr := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("codex executor: close response body error: %v", errClose)
|
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
helps.RecordAPIResponseError(ctx, e.cfg, readErr)
|
||||||
return nil, readErr
|
return nil, readErr
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = newCodexStatusErr(httpResp.StatusCode, data)
|
err = newCodexStatusErr(httpResp.StatusCode, data)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -374,13 +416,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
if bytes.HasPrefix(line, dataTag) {
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
data := bytes.TrimSpace(line[5:])
|
data := bytes.TrimSpace(line[5:])
|
||||||
if gjson.GetBytes(data, "type").String() == "response.completed" {
|
if gjson.GetBytes(data, "type").String() == "response.completed" {
|
||||||
if detail, ok := parseCodexUsage(data); ok {
|
if detail, ok := helps.ParseCodexUsage(data); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -391,8 +433,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -415,10 +457,9 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||||
body, _ = sjson.SetBytes(body, "stream", false)
|
body, _ = sjson.SetBytes(body, "stream", false)
|
||||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
body = normalizeCodexInstructions(body)
|
||||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
enc, err := tokenizerForCodexModel(baseModel)
|
enc, err := tokenizerForCodexModel(baseModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -597,18 +638,18 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
|
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
|
||||||
var cache codexCache
|
var cache helps.CodexCache
|
||||||
if from == "claude" {
|
if from == "claude" {
|
||||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||||
if userIDResult.Exists() {
|
if userIDResult.Exists() {
|
||||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||||
var ok bool
|
var ok bool
|
||||||
if cache, ok = getCodexCache(key); !ok {
|
if cache, ok = helps.GetCodexCache(key); !ok {
|
||||||
cache = codexCache{
|
cache = helps.CodexCache{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Expire: time.Now().Add(1 * time.Hour),
|
Expire: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
setCodexCache(key, cache)
|
helps.SetCodexCache(key, cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if from == "openai-response" {
|
} else if from == "openai-response" {
|
||||||
@@ -617,7 +658,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
cache.ID = promptCacheKey.String()
|
cache.ID = promptCacheKey.String()
|
||||||
}
|
}
|
||||||
} else if from == "openai" {
|
} else if from == "openai" {
|
||||||
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
if apiKey := strings.TrimSpace(helps.APIKeyFromContext(ctx)); apiKey != "" {
|
||||||
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -630,7 +671,6 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if cache.ID != "" {
|
if cache.ID != "" {
|
||||||
httpReq.Header.Set("Conversation_id", cache.ID)
|
|
||||||
httpReq.Header.Set("Session_id", cache.ID)
|
httpReq.Header.Set("Session_id", cache.ID)
|
||||||
}
|
}
|
||||||
return httpReq, nil
|
return httpReq, nil
|
||||||
@@ -645,13 +685,19 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
ginHeaders = ginCtx.Request.Header
|
ginHeaders = ginCtx.Request.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ginHeaders.Get("X-Codex-Beta-Features") != "" {
|
||||||
|
r.Header.Set("X-Codex-Beta-Features", ginHeaders.Get("X-Codex-Beta-Features"))
|
||||||
|
}
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
|
||||||
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
|
||||||
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||||
|
|
||||||
|
if strings.Contains(r.Header.Get("User-Agent"), "Mac OS") {
|
||||||
|
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||||
|
}
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
} else {
|
} else {
|
||||||
@@ -685,13 +731,47 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
func newCodexStatusErr(statusCode int, body []byte) statusErr {
|
||||||
err := statusErr{code: statusCode, msg: string(body)}
|
errCode := statusCode
|
||||||
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
|
if isCodexModelCapacityError(body) {
|
||||||
|
errCode = http.StatusTooManyRequests
|
||||||
|
}
|
||||||
|
err := statusErr{code: errCode, msg: string(body)}
|
||||||
|
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
|
||||||
err.retryAfter = retryAfter
|
err.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeCodexInstructions(body []byte) []byte {
|
||||||
|
instructions := gjson.GetBytes(body, "instructions")
|
||||||
|
if !instructions.Exists() || instructions.Type == gjson.Null {
|
||||||
|
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCodexModelCapacityError(errorBody []byte) bool {
|
||||||
|
if len(errorBody) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
candidates := []string{
|
||||||
|
gjson.GetBytes(errorBody, "error.message").String(),
|
||||||
|
gjson.GetBytes(errorBody, "message").String(),
|
||||||
|
string(errorBody),
|
||||||
|
}
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(candidate))
|
||||||
|
if lower == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(lower, "selected model is at capacity") ||
|
||||||
|
strings.Contains(lower, "model is at capacity. please try a different model") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
|
||||||
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -42,8 +42,8 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
|
|||||||
if gotKey != expectedKey {
|
if gotKey != expectedKey {
|
||||||
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||||
}
|
}
|
||||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != "" {
|
||||||
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
t.Fatalf("Conversation_id = %q, want empty", gotConversation)
|
||||||
}
|
}
|
||||||
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||||
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||||
|
|||||||
79
internal/runtime/executor/codex_executor_compact_test.go
Normal file
79
internal/runtime/executor/codex_executor_compact_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexExecutorCompactAddsDefaultInstructions(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payload string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing instructions",
|
||||||
|
payload: `{"model":"gpt-5.4","input":"hello"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null instructions",
|
||||||
|
payload: `{"model":"gpt-5.4","instructions":null,"input":"hello"}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
var gotBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
gotBody = body
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewCodexExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
"api_key": "test",
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Payload: []byte(tc.payload),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||||
|
Alt: "responses/compact",
|
||||||
|
Stream: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute error: %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/responses/compact" {
|
||||||
|
t.Fatalf("path = %q, want %q", gotPath, "/responses/compact")
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(gotBody, "instructions").Exists() {
|
||||||
|
t.Fatalf("expected instructions in compact request body, got %s", string(gotBody))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
|
||||||
|
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "instructions").String() != "" {
|
||||||
|
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
|
||||||
|
}
|
||||||
|
if string(resp.Payload) != `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}` {
|
||||||
|
t.Fatalf("payload = %s", string(resp.Payload))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
123
internal/runtime/executor/codex_executor_instructions_test.go
Normal file
123
internal/runtime/executor/codex_executor_instructions_test.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexExecutorExecuteNormalizesNullInstructions(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
var gotBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
gotBody = body
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewCodexExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
"api_key": "test",
|
||||||
|
}}
|
||||||
|
|
||||||
|
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||||
|
Stream: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute error: %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/responses" {
|
||||||
|
t.Fatalf("path = %q, want %q", gotPath, "/responses")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
|
||||||
|
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "instructions").String() != "" {
|
||||||
|
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexExecutorExecuteStreamNormalizesNullInstructions(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
var gotBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
gotBody = body
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"background\":false,\"error\":null}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewCodexExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
"api_key": "test",
|
||||||
|
}}
|
||||||
|
|
||||||
|
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||||
|
Stream: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStream error: %v", err)
|
||||||
|
}
|
||||||
|
for range result.Chunks {
|
||||||
|
}
|
||||||
|
if gotPath != "/responses" {
|
||||||
|
t.Fatalf("path = %q, want %q", gotPath, "/responses")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "instructions").Type != gjson.String {
|
||||||
|
t.Fatalf("instructions type = %v, want string", gjson.GetBytes(gotBody, "instructions").Type)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "instructions").String() != "" {
|
||||||
|
t.Fatalf("instructions = %q, want empty string", gjson.GetBytes(gotBody, "instructions").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexExecutorCountTokensTreatsNullInstructionsAsEmpty(t *testing.T) {
|
||||||
|
executor := NewCodexExecutor(&config.Config{})
|
||||||
|
|
||||||
|
nullResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Payload: []byte(`{"model":"gpt-5.4","instructions":null,"input":"hello"}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountTokens(null) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
emptyResp, err := executor.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.4",
|
||||||
|
Payload: []byte(`{"model":"gpt-5.4","instructions":"","input":"hello"}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountTokens(empty) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(nullResp.Payload) != string(emptyResp.Payload) {
|
||||||
|
t.Fatalf("token count payload mismatch:\nnull=%s\nempty=%s", string(nullResp.Payload), string(emptyResp.Payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
|
||||||
|
|
||||||
|
err := newCodexStatusErr(http.StatusBadRequest, body)
|
||||||
|
|
||||||
|
if got := err.StatusCode(); got != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if err.RetryAfter() != nil {
|
||||||
|
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func itoa(v int64) string {
|
func itoa(v int64) string {
|
||||||
return strconv.FormatInt(v, 10)
|
return strconv.FormatInt(v, 10)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewCodexExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
"api_key": "test",
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.4-mini",
|
||||||
|
Payload: []byte(`{"model":"gpt-5.4-mini","messages":[{"role":"user","content":"Say ok"}]}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
Stream: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotContent := gjson.GetBytes(resp.Payload, "choices.0.message.content").String()
|
||||||
|
if gotContent != "ok" {
|
||||||
|
t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", gotContent, "ok", string(resp.Payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,10 +15,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -44,10 +46,18 @@ const (
|
|||||||
type CodexWebsocketsExecutor struct {
|
type CodexWebsocketsExecutor struct {
|
||||||
*CodexExecutor
|
*CodexExecutor
|
||||||
|
|
||||||
sessMu sync.Mutex
|
store *codexWebsocketSessionStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexWebsocketSessionStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
sessions map[string]*codexWebsocketSession
|
sessions map[string]*codexWebsocketSession
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
|
||||||
|
sessions: make(map[string]*codexWebsocketSession),
|
||||||
|
}
|
||||||
|
|
||||||
type codexWebsocketSession struct {
|
type codexWebsocketSession struct {
|
||||||
sessionID string
|
sessionID string
|
||||||
|
|
||||||
@@ -71,7 +81,7 @@ type codexWebsocketSession struct {
|
|||||||
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
||||||
return &CodexWebsocketsExecutor{
|
return &CodexWebsocketsExecutor{
|
||||||
CodexExecutor: NewCodexExecutor(cfg),
|
CodexExecutor: NewCodexExecutor(cfg),
|
||||||
sessions: make(map[string]*codexWebsocketSession),
|
store: globalCodexWebsocketSessionStore,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,8 +165,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
@@ -173,8 +183,8 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
@@ -209,7 +219,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
wsReqLog := helps.UpstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
Headers: wsHeaders.Clone(),
|
Headers: wsHeaders.Clone(),
|
||||||
@@ -219,16 +229,14 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
AuthLabel: authLabel,
|
AuthLabel: authLabel,
|
||||||
AuthType: authType,
|
AuthType: authType,
|
||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
}
|
||||||
|
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
|
||||||
|
|
||||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||||
if respHS != nil {
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
|
||||||
}
|
|
||||||
if errDial != nil {
|
if errDial != nil {
|
||||||
bodyErr := websocketHandshakeBody(respHS)
|
bodyErr := websocketHandshakeBody(respHS)
|
||||||
if len(bodyErr) > 0 {
|
if respHS != nil {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
|
||||||
}
|
}
|
||||||
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
||||||
return e.CodexExecutor.Execute(ctx, auth, req, opts)
|
return e.CodexExecutor.Execute(ctx, auth, req, opts)
|
||||||
@@ -236,10 +244,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
if respHS != nil && respHS.StatusCode > 0 {
|
if respHS != nil && respHS.StatusCode > 0 {
|
||||||
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
||||||
}
|
}
|
||||||
recordAPIResponseError(ctx, e.cfg, errDial)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
|
||||||
return resp, errDial
|
return resp, errDial
|
||||||
}
|
}
|
||||||
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -268,10 +276,10 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
// Retry once with a fresh websocket connection. This is mainly to handle
|
// Retry once with a fresh websocket connection. This is mainly to handle
|
||||||
// upstream closing the socket between sequential requests within the same
|
// upstream closing the socket between sequential requests within the same
|
||||||
// execution session.
|
// execution session.
|
||||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||||
if errDialRetry == nil && connRetry != nil {
|
if errDialRetry == nil && connRetry != nil {
|
||||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
Headers: wsHeaders.Clone(),
|
Headers: wsHeaders.Clone(),
|
||||||
@@ -282,20 +290,22 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
AuthType: authType,
|
AuthType: authType,
|
||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
|
||||||
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
|
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
|
||||||
conn = connRetry
|
conn = connRetry
|
||||||
wsReqBody = wsReqBodyRetry
|
wsReqBody = wsReqBodyRetry
|
||||||
} else {
|
} else {
|
||||||
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
||||||
recordAPIResponseError(ctx, e.cfg, errSendRetry)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
|
||||||
return resp, errSendRetry
|
return resp, errSendRetry
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDialRetry)
|
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
|
||||||
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
|
||||||
return resp, errDialRetry
|
return resp, errDialRetry
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
recordAPIResponseError(ctx, e.cfg, errSend)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
|
||||||
return resp, errSend
|
return resp, errSend
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -306,7 +316,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
|
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
|
||||||
return resp, errRead
|
return resp, errRead
|
||||||
}
|
}
|
||||||
if msgType != websocket.TextMessage {
|
if msgType != websocket.TextMessage {
|
||||||
@@ -315,7 +325,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
if sess != nil {
|
if sess != nil {
|
||||||
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
||||||
}
|
}
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -325,21 +335,21 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, payload)
|
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
|
||||||
|
|
||||||
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
||||||
}
|
}
|
||||||
recordAPIResponseError(ctx, e.cfg, wsErr)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
|
||||||
return resp, wsErr
|
return resp, wsErr
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = normalizeCodexWebsocketCompletion(payload)
|
payload = normalizeCodexWebsocketCompletion(payload)
|
||||||
eventType := gjson.GetBytes(payload, "type").String()
|
eventType := gjson.GetBytes(payload, "type").String()
|
||||||
if eventType == "response.completed" {
|
if eventType == "response.completed" {
|
||||||
if detail, ok := parseCodexUsage(payload); ok {
|
if detail, ok := helps.ParseCodexUsage(payload); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
||||||
@@ -364,8 +374,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
@@ -376,8 +386,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
|
||||||
|
|
||||||
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
||||||
@@ -403,7 +413,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
}
|
}
|
||||||
|
|
||||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
wsReqLog := helps.UpstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
Headers: wsHeaders.Clone(),
|
Headers: wsHeaders.Clone(),
|
||||||
@@ -413,18 +423,18 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
AuthLabel: authLabel,
|
AuthLabel: authLabel,
|
||||||
AuthType: authType,
|
AuthType: authType,
|
||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
}
|
||||||
|
helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog)
|
||||||
|
|
||||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||||
var upstreamHeaders http.Header
|
var upstreamHeaders http.Header
|
||||||
if respHS != nil {
|
if respHS != nil {
|
||||||
upstreamHeaders = respHS.Header.Clone()
|
upstreamHeaders = respHS.Header.Clone()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
|
||||||
}
|
}
|
||||||
if errDial != nil {
|
if errDial != nil {
|
||||||
bodyErr := websocketHandshakeBody(respHS)
|
bodyErr := websocketHandshakeBody(respHS)
|
||||||
if len(bodyErr) > 0 {
|
if respHS != nil {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr)
|
||||||
}
|
}
|
||||||
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
||||||
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
|
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||||
@@ -432,13 +442,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
if respHS != nil && respHS.StatusCode > 0 {
|
if respHS != nil && respHS.StatusCode > 0 {
|
||||||
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
||||||
}
|
}
|
||||||
recordAPIResponseError(ctx, e.cfg, errDial)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial)
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
sess.reqMu.Unlock()
|
sess.reqMu.Unlock()
|
||||||
}
|
}
|
||||||
return nil, errDial
|
return nil, errDial
|
||||||
}
|
}
|
||||||
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
recordAPIWebsocketHandshake(ctx, e.cfg, respHS)
|
||||||
|
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
||||||
@@ -451,20 +461,21 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
|
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errSend)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend)
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
|
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
|
||||||
|
|
||||||
// Retry once with a new websocket connection for the same execution session.
|
// Retry once with a new websocket connection for the same execution session.
|
||||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||||
if errDialRetry != nil || connRetry == nil {
|
if errDialRetry != nil || connRetry == nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDialRetry)
|
closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error")
|
||||||
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry)
|
||||||
sess.clearActive(readCh)
|
sess.clearActive(readCh)
|
||||||
sess.reqMu.Unlock()
|
sess.reqMu.Unlock()
|
||||||
return nil, errDialRetry
|
return nil, errDialRetry
|
||||||
}
|
}
|
||||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
Headers: wsHeaders.Clone(),
|
Headers: wsHeaders.Clone(),
|
||||||
@@ -475,8 +486,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
AuthType: authType,
|
AuthType: authType,
|
||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry)
|
||||||
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
|
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errSendRetry)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry)
|
||||||
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
||||||
sess.clearActive(readCh)
|
sess.clearActive(readCh)
|
||||||
sess.reqMu.Unlock()
|
sess.reqMu.Unlock()
|
||||||
@@ -542,8 +554,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
}
|
}
|
||||||
terminateReason = "read_error"
|
terminateReason = "read_error"
|
||||||
terminateErr = errRead
|
terminateErr = errRead
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
|
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -552,8 +564,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
err = fmt.Errorf("codex websockets executor: unexpected binary message")
|
err = fmt.Errorf("codex websockets executor: unexpected binary message")
|
||||||
terminateReason = "unexpected_binary"
|
terminateReason = "unexpected_binary"
|
||||||
terminateErr = err
|
terminateErr = err
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
||||||
}
|
}
|
||||||
@@ -567,13 +579,13 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, payload)
|
helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload)
|
||||||
|
|
||||||
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
||||||
terminateReason = "upstream_error"
|
terminateReason = "upstream_error"
|
||||||
terminateErr = wsErr
|
terminateErr = wsErr
|
||||||
recordAPIResponseError(ctx, e.cfg, wsErr)
|
helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
||||||
}
|
}
|
||||||
@@ -584,8 +596,8 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
payload = normalizeCodexWebsocketCompletion(payload)
|
payload = normalizeCodexWebsocketCompletion(payload)
|
||||||
eventType := gjson.GetBytes(payload, "type").String()
|
eventType := gjson.GetBytes(payload, "type").String()
|
||||||
if eventType == "response.completed" || eventType == "response.done" {
|
if eventType == "response.completed" || eventType == "response.done" {
|
||||||
if detail, ok := parseCodexUsage(payload); ok {
|
if detail, ok := helps.ParseCodexUsage(payload); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -722,7 +734,7 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch setting.URL.Scheme {
|
switch setting.URL.Scheme {
|
||||||
case "socks5":
|
case "socks5", "socks5h":
|
||||||
var proxyAuth *proxy.Auth
|
var proxyAuth *proxy.Auth
|
||||||
if setting.URL.User != nil {
|
if setting.URL.User != nil {
|
||||||
username := setting.URL.User.Username()
|
username := setting.URL.User.Username()
|
||||||
@@ -767,19 +779,19 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
|||||||
return rawJSON, headers
|
return rawJSON, headers
|
||||||
}
|
}
|
||||||
|
|
||||||
var cache codexCache
|
var cache helps.CodexCache
|
||||||
if from == "claude" {
|
if from == "claude" {
|
||||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||||
if userIDResult.Exists() {
|
if userIDResult.Exists() {
|
||||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||||
if cached, ok := getCodexCache(key); ok {
|
if cached, ok := helps.GetCodexCache(key); ok {
|
||||||
cache = cached
|
cache = cached
|
||||||
} else {
|
} else {
|
||||||
cache = codexCache{
|
cache = helps.CodexCache{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Expire: time.Now().Add(1 * time.Hour),
|
Expire: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
setCodexCache(key, cache)
|
helps.SetCodexCache(key, cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if from == "openai-response" {
|
} else if from == "openai-response" {
|
||||||
@@ -791,7 +803,6 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
|||||||
if cache.ID != "" {
|
if cache.ID != "" {
|
||||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||||
headers.Set("Conversation_id", cache.ID)
|
headers.Set("Conversation_id", cache.ID)
|
||||||
headers.Set("Session_id", cache.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawJSON, headers
|
return rawJSON, headers
|
||||||
@@ -806,11 +817,11 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ginHeaders http.Header
|
var ginHeaders http.Header
|
||||||
if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil {
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
ginHeaders = ginCtx.Request.Header
|
ginHeaders = ginCtx.Request.Header.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||||
@@ -826,8 +837,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
|||||||
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
||||||
}
|
}
|
||||||
headers.Set("OpenAI-Beta", betaHeader)
|
headers.Set("OpenAI-Beta", betaHeader)
|
||||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
|
||||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||||
|
}
|
||||||
|
headers.Del("User-Agent")
|
||||||
|
|
||||||
isAPIKey := false
|
isAPIKey := false
|
||||||
if auth != nil && auth.Attributes != nil {
|
if auth != nil && auth.Attributes != nil {
|
||||||
@@ -1011,6 +1024,32 @@ func encodeCodexWebsocketAsSSE(payload []byte) []byte {
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog {
|
||||||
|
upgradeInfo := info
|
||||||
|
upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL)
|
||||||
|
upgradeInfo.Method = http.MethodGet
|
||||||
|
upgradeInfo.Body = nil
|
||||||
|
upgradeInfo.Headers = info.Headers.Clone()
|
||||||
|
if upgradeInfo.Headers == nil {
|
||||||
|
upgradeInfo.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" {
|
||||||
|
upgradeInfo.Headers.Set("Connection", "Upgrade")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" {
|
||||||
|
upgradeInfo.Headers.Set("Upgrade", "websocket")
|
||||||
|
}
|
||||||
|
return upgradeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) {
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone())
|
||||||
|
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
|
||||||
|
}
|
||||||
|
|
||||||
func websocketHandshakeBody(resp *http.Response) []byte {
|
func websocketHandshakeBody(resp *http.Response) []byte {
|
||||||
if resp == nil || resp.Body == nil {
|
if resp == nil || resp.Body == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -1055,16 +1094,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
|
|||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.sessMu.Lock()
|
if e == nil {
|
||||||
defer e.sessMu.Unlock()
|
return nil
|
||||||
if e.sessions == nil {
|
|
||||||
e.sessions = make(map[string]*codexWebsocketSession)
|
|
||||||
}
|
}
|
||||||
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
|
store := e.store
|
||||||
|
if store == nil {
|
||||||
|
store = globalCodexWebsocketSessionStore
|
||||||
|
}
|
||||||
|
store.mu.Lock()
|
||||||
|
defer store.mu.Unlock()
|
||||||
|
if store.sessions == nil {
|
||||||
|
store.sessions = make(map[string]*codexWebsocketSession)
|
||||||
|
}
|
||||||
|
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
sess := &codexWebsocketSession{sessionID: sessionID}
|
sess := &codexWebsocketSession{sessionID: sessionID}
|
||||||
e.sessions[sessionID] = sess
|
store.sessions[sessionID] = sess
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1210,14 +1256,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
||||||
e.closeAllExecutionSessions("executor_replaced")
|
// Executor replacement can happen during hot reload (config/credential changes).
|
||||||
|
// Do not force-close upstream websocket sessions here, otherwise in-flight
|
||||||
|
// downstream websocket requests get interrupted.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
e.sessMu.Lock()
|
store := e.store
|
||||||
sess := e.sessions[sessionID]
|
if store == nil {
|
||||||
delete(e.sessions, sessionID)
|
store = globalCodexWebsocketSessionStore
|
||||||
e.sessMu.Unlock()
|
}
|
||||||
|
store.mu.Lock()
|
||||||
|
sess := store.sessions[sessionID]
|
||||||
|
delete(store.sessions, sessionID)
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
e.closeExecutionSession(sess, "session_closed")
|
e.closeExecutionSession(sess, "session_closed")
|
||||||
}
|
}
|
||||||
@@ -1227,15 +1279,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
e.sessMu.Lock()
|
store := e.store
|
||||||
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
|
if store == nil {
|
||||||
for sessionID, sess := range e.sessions {
|
store = globalCodexWebsocketSessionStore
|
||||||
delete(e.sessions, sessionID)
|
}
|
||||||
|
store.mu.Lock()
|
||||||
|
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
|
||||||
|
for sessionID, sess := range store.sessions {
|
||||||
|
delete(store.sessions, sessionID)
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
sessions = append(sessions, sess)
|
sessions = append(sessions, sess)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e.sessMu.Unlock()
|
store.mu.Unlock()
|
||||||
|
|
||||||
for i := range sessions {
|
for i := range sessions {
|
||||||
e.closeExecutionSession(sessions[i], reason)
|
e.closeExecutionSession(sessions[i], reason)
|
||||||
@@ -1243,6 +1299,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
||||||
|
closeCodexWebsocketSession(sess, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1283,6 +1343,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
|
|||||||
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
|
||||||
|
// associated with the supplied auth ID.
|
||||||
|
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if reason == "" {
|
||||||
|
reason = "auth_removed"
|
||||||
|
}
|
||||||
|
|
||||||
|
store := globalCodexWebsocketSessionStore
|
||||||
|
if store == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionItem struct {
|
||||||
|
sessionID string
|
||||||
|
sess *codexWebsocketSession
|
||||||
|
}
|
||||||
|
|
||||||
|
store.mu.Lock()
|
||||||
|
items := make([]sessionItem, 0, len(store.sessions))
|
||||||
|
for sessionID, sess := range store.sessions {
|
||||||
|
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
|
||||||
|
}
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
|
matches := make([]sessionItem, 0)
|
||||||
|
for i := range items {
|
||||||
|
sess := items[i].sess
|
||||||
|
if sess == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sess.connMu.Lock()
|
||||||
|
sessAuthID := strings.TrimSpace(sess.authID)
|
||||||
|
sess.connMu.Unlock()
|
||||||
|
if sessAuthID == authID {
|
||||||
|
matches = append(matches, items[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toClose := make([]*codexWebsocketSession, 0, len(matches))
|
||||||
|
store.mu.Lock()
|
||||||
|
for i := range matches {
|
||||||
|
current, ok := store.sessions[matches[i].sessionID]
|
||||||
|
if !ok || current == nil || current != matches[i].sess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(store.sessions, matches[i].sessionID)
|
||||||
|
toClose = append(toClose, current)
|
||||||
|
}
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range toClose {
|
||||||
|
closeCodexWebsocketSession(toClose[i], reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
||||||
// 1. The downstream transport is websocket, and
|
// 1. The downstream transport is websocket, and
|
||||||
// 2. The selected auth enables websockets.
|
// 2. The selected auth enables websockets.
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
|
||||||
|
sessionID := "test-session-store-survives-replace"
|
||||||
|
|
||||||
|
globalCodexWebsocketSessionStore.mu.Lock()
|
||||||
|
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
|
||||||
|
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||||
|
|
||||||
|
exec1 := NewCodexWebsocketsExecutor(nil)
|
||||||
|
sess1 := exec1.getOrCreateSession(sessionID)
|
||||||
|
if sess1 == nil {
|
||||||
|
t.Fatalf("expected session to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
exec2 := NewCodexWebsocketsExecutor(nil)
|
||||||
|
sess2 := exec2.getOrCreateSession(sessionID)
|
||||||
|
if sess2 == nil {
|
||||||
|
t.Fatalf("expected session to be available across executors")
|
||||||
|
}
|
||||||
|
if sess1 != sess2 {
|
||||||
|
t.Fatalf("expected the same session instance across executors")
|
||||||
|
}
|
||||||
|
|
||||||
|
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
|
||||||
|
|
||||||
|
globalCodexWebsocketSessionStore.mu.Lock()
|
||||||
|
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||||
|
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||||
|
if !stillPresent {
|
||||||
|
t.Fatalf("expected session to remain after executor replacement close marker")
|
||||||
|
}
|
||||||
|
|
||||||
|
exec2.CloseExecutionSession(sessionID)
|
||||||
|
|
||||||
|
globalCodexWebsocketSessionStore.mu.Lock()
|
||||||
|
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||||
|
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||||
|
if presentAfterClose {
|
||||||
|
t.Fatalf("expected session to be removed after explicit close")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -38,8 +38,8 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
|
|||||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||||
}
|
}
|
||||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
if got := headers.Get("User-Agent"); got != "" {
|
||||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
t.Fatalf("User-Agent = %s, want empty", got)
|
||||||
}
|
}
|
||||||
if got := headers.Get("Version"); got != "" {
|
if got := headers.Get("Version"); got != "" {
|
||||||
t.Fatalf("Version = %q, want empty", got)
|
t.Fatalf("Version = %q, want empty", got)
|
||||||
@@ -97,8 +97,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
|||||||
|
|
||||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
||||||
|
|
||||||
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
|
if got := headers.Get("User-Agent"); got != "" {
|
||||||
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
|
t.Fatalf("User-Agent = %s, want empty", got)
|
||||||
}
|
}
|
||||||
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
||||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
||||||
@@ -129,8 +129,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
|
|||||||
|
|
||||||
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
||||||
|
|
||||||
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
|
if gotVal := got.Get("User-Agent"); gotVal != "" {
|
||||||
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
|
t.Fatalf("User-Agent = %s, want empty", gotVal)
|
||||||
}
|
}
|
||||||
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
||||||
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
||||||
@@ -155,8 +155,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
|
|||||||
|
|
||||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
||||||
|
|
||||||
if got := headers.Get("User-Agent"); got != "config-ua" {
|
if got := headers.Get("User-Agent"); got != "" {
|
||||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
t.Fatalf("User-Agent = %s, want empty", got)
|
||||||
}
|
}
|
||||||
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
||||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
||||||
@@ -177,8 +177,8 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
|
|||||||
|
|
||||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
|
||||||
|
|
||||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
if got := headers.Get("User-Agent"); got != "" {
|
||||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
t.Fatalf("User-Agent = %s, want empty", got)
|
||||||
}
|
}
|
||||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||||
|
|||||||
129
internal/runtime/executor/compat_helpers.go
Normal file
129
internal/runtime/executor/compat_helpers.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tiktoken-go/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
|
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpenAIUsage(data []byte) usage.Detail {
|
||||||
|
return helps.ParseOpenAIUsage(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
|
return helps.ParseOpenAIStreamUsage(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
|
||||||
|
return helps.ParseOpenAIUsage(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
|
return helps.ParseOpenAIStreamUsage(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTokenizer(model string) (tokenizer.Codec, error) {
|
||||||
|
return helps.TokenizerForModel(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||||
|
return helps.CountOpenAIChatTokens(enc, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func countClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||||
|
return helps.CountClaudeChatTokens(enc, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIUsageJSON(count int64) []byte {
|
||||||
|
return helps.BuildOpenAIUsageJSON(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamRequestLog = helps.UpstreamRequestLog
|
||||||
|
|
||||||
|
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
|
||||||
|
helps.RecordAPIRequest(ctx, cfg, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, cfg, status, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||||
|
helps.RecordAPIResponseError(ctx, cfg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||||
|
helps.AppendAPIResponseChunk(ctx, cfg, chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||||
|
return helps.PayloadRequestedModel(opts, fallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||||
|
return helps.ApplyPayloadConfigWithRoot(cfg, model, protocol, root, payload, original, requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeErrorBody(contentType string, body []byte) string {
|
||||||
|
return helps.SummarizeErrorBody(contentType, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func apiKeyFromContext(ctx context.Context) string {
|
||||||
|
return helps.APIKeyFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||||
|
return helps.TokenizerForModel(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||||
|
helps.CollectOpenAIContent(content, segments)
|
||||||
|
}
|
||||||
|
|
||||||
|
type usageReporter struct {
|
||||||
|
reporter *helps.UsageReporter
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
|
||||||
|
return &usageReporter{reporter: helps.NewUsageReporter(ctx, provider, model, auth)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
|
||||||
|
if r == nil || r.reporter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.reporter.Publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageReporter) publishFailure(ctx context.Context) {
|
||||||
|
if r == nil || r.reporter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.reporter.PublishFailure(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
|
||||||
|
if r == nil || r.reporter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.reporter.TrackFailure(ctx, errPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||||
|
if r == nil || r.reporter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.reporter.EnsurePublished(ctx)
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
@@ -81,6 +82,11 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
|
|||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(req, "unknown")
|
applyGeminiCLIHeaders(req, "unknown")
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,8 +118,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
@@ -132,8 +138,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -190,7 +196,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||||
reqHTTP.Header.Set("Accept", "application/json")
|
reqHTTP.Header.Set("Accept", "application/json")
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||||
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: reqHTTP.Header.Clone(),
|
Headers: reqHTTP.Header.Clone(),
|
||||||
@@ -204,7 +211,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
httpResp, errDo := httpClient.Do(reqHTTP)
|
httpResp, errDo := httpClient.Do(reqHTTP)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
err = errDo
|
err = errDo
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -213,15 +220,15 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
err = errRead
|
err = errRead
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
||||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
@@ -230,7 +237,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -245,7 +252,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(lastBody) > 0 {
|
if len(lastBody) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, lastBody)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
|
||||||
}
|
}
|
||||||
if lastStatus == 0 {
|
if lastStatus == 0 {
|
||||||
lastStatus = 429
|
lastStatus = 429
|
||||||
@@ -266,8 +273,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
@@ -286,8 +293,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||||
|
|
||||||
projectID := resolveGeminiProjectID(auth)
|
projectID := resolveGeminiProjectID(auth)
|
||||||
|
|
||||||
@@ -335,7 +342,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
applyGeminiCLIHeaders(reqHTTP, attemptModel)
|
||||||
reqHTTP.Header.Set("Accept", "text/event-stream")
|
reqHTTP.Header.Set("Accept", "text/event-stream")
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||||
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: reqHTTP.Header.Clone(),
|
Headers: reqHTTP.Header.Clone(),
|
||||||
@@ -349,25 +357,25 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
|
|
||||||
httpResp, errDo := httpClient.Do(reqHTTP)
|
httpResp, errDo := httpClient.Do(reqHTTP)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
err = errDo
|
err = errDo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
err = errRead
|
err = errRead
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -394,9 +402,9 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseGeminiCLIStreamUsage(line); ok {
|
if detail, ok := helps.ParseGeminiCLIStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
if bytes.HasPrefix(line, dataTag) {
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||||
@@ -411,8 +419,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -420,13 +428,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
|
|
||||||
data, errRead := io.ReadAll(resp.Body)
|
data, errRead := io.ReadAll(resp.Body)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errRead}
|
out <- cliproxyexecutor.StreamChunk{Err: errRead}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||||
for i := range segments {
|
for i := range segments {
|
||||||
@@ -443,7 +451,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(lastBody) > 0 {
|
if len(lastBody) > 0 {
|
||||||
appendAPIResponseChunk(ctx, e.cfg, lastBody)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody)
|
||||||
}
|
}
|
||||||
if lastStatus == 0 {
|
if lastStatus == 0 {
|
||||||
lastStatus = 429
|
lastStatus = 429
|
||||||
@@ -516,7 +524,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
applyGeminiCLIHeaders(reqHTTP, baseModel)
|
||||||
reqHTTP.Header.Set("Accept", "application/json")
|
reqHTTP.Header.Set("Accept", "application/json")
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes)
|
||||||
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: reqHTTP.Header.Clone(),
|
Headers: reqHTTP.Header.Clone(),
|
||||||
@@ -530,17 +539,19 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
|
|
||||||
resp, errDo := httpClient.Do(reqHTTP)
|
resp, errDo := httpClient.Do(reqHTTP)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return cliproxyexecutor.Response{}, errDo
|
return cliproxyexecutor.Response{}, errDo
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(resp.Body)
|
data, errRead := io.ReadAll(resp.Body)
|
||||||
_ = resp.Body.Close()
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
|
||||||
|
}
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
return cliproxyexecutor.Response{}, errRead
|
return cliproxyexecutor.Response{}, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
@@ -611,7 +622,7 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctxToken := ctx
|
ctxToken := ctx
|
||||||
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||||
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -707,7 +718,7 @@ func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cloneMap(in map[string]any) map[string]any {
|
func cloneMap(in map[string]any) map[string]any {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -85,7 +86,7 @@ func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,8 +111,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
|
|
||||||
apiKey, bearer := geminiCreds(auth)
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
// Official Gemini API via API key or OAuth bearer
|
// Official Gemini API via API key or OAuth bearer
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
@@ -130,8 +131,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
@@ -165,7 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -177,10 +178,10 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -188,21 +189,21 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
@@ -218,8 +219,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
|
|
||||||
apiKey, bearer := geminiCreds(auth)
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
@@ -237,8 +238,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
baseURL := resolveGeminiBaseURL(auth)
|
baseURL := resolveGeminiBaseURL(auth)
|
||||||
@@ -268,7 +269,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -280,17 +281,17 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -310,14 +311,14 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
filtered := FilterSSEUsageMetadata(line)
|
filtered := helps.FilterSSEUsageMetadata(line)
|
||||||
payload := jsonPayload(filtered)
|
payload := helps.JSONPayload(filtered)
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if detail, ok := parseGeminiStreamUsage(payload); ok {
|
if detail, ok := helps.ParseGeminiStreamUsage(payload); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
@@ -329,8 +330,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -381,7 +382,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -393,23 +394,27 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
resp, err := httpClient.Do(httpReq)
|
resp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() {
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, helps.SummarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ import (
|
|||||||
|
|
||||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -227,7 +229,7 @@ func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,8 +303,8 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
|
|||||||
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
var body []byte
|
var body []byte
|
||||||
|
|
||||||
@@ -332,8 +334,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,6 +364,11 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
return resp, statusErr{code: 500, msg: "internal server error"}
|
return resp, statusErr{code: 500, msg: "internal server error"}
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -369,7 +376,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -381,10 +388,10 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return resp, errDo
|
return resp, errDo
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -392,21 +399,21 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
return resp, errRead
|
return resp, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
|
||||||
|
|
||||||
// For Imagen models, convert response to Gemini format before translation
|
// For Imagen models, convert response to Gemini format before translation
|
||||||
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||||
@@ -427,8 +434,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
@@ -447,8 +454,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := getVertexAction(baseModel, false)
|
action := getVertexAction(baseModel, false)
|
||||||
@@ -477,6 +484,11 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -484,7 +496,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -496,10 +508,10 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return resp, errDo
|
return resp, errDo
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -507,21 +519,21 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
return resp, errRead
|
return resp, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.Publish(ctx, helps.ParseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
|
||||||
@@ -532,8 +544,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
@@ -552,8 +564,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := getVertexAction(baseModel, true)
|
action := getVertexAction(baseModel, true)
|
||||||
@@ -581,6 +593,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
return nil, statusErr{code: 500, msg: "internal server error"}
|
return nil, statusErr{code: 500, msg: "internal server error"}
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -588,7 +605,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -600,17 +617,17 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return nil, errDo
|
return nil, errDo
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -630,9 +647,9 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
@@ -644,8 +661,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -656,8 +673,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
@@ -676,8 +693,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := getVertexAction(baseModel, true)
|
action := getVertexAction(baseModel, true)
|
||||||
@@ -705,6 +722,11 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -712,7 +734,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -724,17 +746,17 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return nil, errDo
|
return nil, errDo
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -754,9 +776,9 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
if detail, ok := helps.ParseGeminiStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range lines {
|
for i := range lines {
|
||||||
@@ -768,8 +790,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -812,6 +834,11 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -819,7 +846,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -831,10 +858,10 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return cliproxyexecutor.Response{}, errDo
|
return cliproxyexecutor.Response{}, errDo
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -842,19 +869,19 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
return cliproxyexecutor.Response{}, errRead
|
return cliproxyexecutor.Response{}, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||||
@@ -896,6 +923,11 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -903,7 +935,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -915,10 +947,10 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return cliproxyexecutor.Response{}, errDo
|
return cliproxyexecutor.Response{}, errDo
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -926,19 +958,19 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
return cliproxyexecutor.Response{}, errRead
|
return cliproxyexecutor.Response{}, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||||
@@ -1012,7 +1044,7 @@ func vertexBaseURL(location string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
||||||
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||||
}
|
}
|
||||||
// Use cloud-platform scope for Vertex AI.
|
// Use cloud-platform scope for Vertex AI.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -40,7 +42,7 @@ const (
|
|||||||
copilotEditorVersion = "vscode/1.107.0"
|
copilotEditorVersion = "vscode/1.107.0"
|
||||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||||
copilotIntegrationID = "vscode-chat"
|
copilotIntegrationID = "vscode-chat"
|
||||||
copilotOpenAIIntent = "conversation-panel"
|
copilotOpenAIIntent = "conversation-edits"
|
||||||
copilotGitHubAPIVer = "2025-04-01"
|
copilotGitHubAPIVer = "2025-04-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -126,6 +128,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
body = flattenAssistantContent(body)
|
body = flattenAssistantContent(body)
|
||||||
|
body = stripUnsupportedBetas(body)
|
||||||
|
|
||||||
// Detect vision content before input normalization removes messages
|
// Detect vision content before input normalization removes messages
|
||||||
hasVision := detectVisionContent(body)
|
hasVision := detectVisionContent(body)
|
||||||
@@ -142,6 +145,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
if useResponses {
|
if useResponses {
|
||||||
body = normalizeGitHubCopilotResponsesInput(body)
|
body = normalizeGitHubCopilotResponsesInput(body)
|
||||||
body = normalizeGitHubCopilotResponsesTools(body)
|
body = normalizeGitHubCopilotResponsesTools(body)
|
||||||
|
body = applyGitHubCopilotResponsesDefaults(body)
|
||||||
} else {
|
} else {
|
||||||
body = normalizeGitHubCopilotChatTools(body)
|
body = normalizeGitHubCopilotChatTools(body)
|
||||||
}
|
}
|
||||||
@@ -225,9 +229,10 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
if useResponses && from.String() == "claude" {
|
if useResponses && from.String() == "claude" {
|
||||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||||
} else {
|
} else {
|
||||||
|
data = normalizeGitHubCopilotReasoningField(data)
|
||||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
}
|
}
|
||||||
resp = cliproxyexecutor.Response{Payload: converted}
|
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -256,6 +261,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
body = flattenAssistantContent(body)
|
body = flattenAssistantContent(body)
|
||||||
|
body = stripUnsupportedBetas(body)
|
||||||
|
|
||||||
// Detect vision content before input normalization removes messages
|
// Detect vision content before input normalization removes messages
|
||||||
hasVision := detectVisionContent(body)
|
hasVision := detectVisionContent(body)
|
||||||
@@ -272,6 +278,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
if useResponses {
|
if useResponses {
|
||||||
body = normalizeGitHubCopilotResponsesInput(body)
|
body = normalizeGitHubCopilotResponsesInput(body)
|
||||||
body = normalizeGitHubCopilotResponsesTools(body)
|
body = normalizeGitHubCopilotResponsesTools(body)
|
||||||
|
body = applyGitHubCopilotResponsesDefaults(body)
|
||||||
} else {
|
} else {
|
||||||
body = normalizeGitHubCopilotChatTools(body)
|
body = normalizeGitHubCopilotChatTools(body)
|
||||||
}
|
}
|
||||||
@@ -378,7 +385,20 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
if useResponses && from.String() == "claude" {
|
if useResponses && from.String() == "claude" {
|
||||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||||
} else {
|
} else {
|
||||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
// Strip SSE "data: " prefix before reasoning field normalization,
|
||||||
|
// since normalizeGitHubCopilotReasoningField expects pure JSON.
|
||||||
|
// Re-wrap with the prefix afterward for the translator.
|
||||||
|
normalizedLine := bytes.Clone(line)
|
||||||
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
|
sseData := bytes.TrimSpace(line[len(dataTag):])
|
||||||
|
if !bytes.Equal(sseData, []byte("[DONE]")) && gjson.ValidBytes(sseData) {
|
||||||
|
normalized := normalizeGitHubCopilotReasoningField(bytes.Clone(sseData))
|
||||||
|
if !bytes.Equal(normalized, sseData) {
|
||||||
|
normalizedLine = append(append([]byte(nil), dataTag...), normalized...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, normalizedLine, ¶m)
|
||||||
}
|
}
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
|
||||||
@@ -400,9 +420,28 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens is not supported for GitHub Copilot.
|
// CountTokens estimates token count locally using tiktoken, since the GitHub
|
||||||
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
// Copilot API does not expose a dedicated token counting endpoint.
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
|
func (e *GitHubCopilotExecutor) CountTokens(ctx context.Context, _ *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("openai")
|
||||||
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||||
|
|
||||||
|
enc, err := helps.TokenizerForModel(baseModel)
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: tokenizer init failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := helps.CountOpenAIChatTokens(enc, translated)
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: token counting failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||||
|
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
|
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh validates the GitHub token is still working.
|
// Refresh validates the GitHub token is still working.
|
||||||
@@ -491,46 +530,127 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
|
|||||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||||
|
|
||||||
initiator := "user"
|
initiator := "user"
|
||||||
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" {
|
if isAgentInitiated(body) {
|
||||||
initiator = "agent"
|
initiator = "agent"
|
||||||
}
|
}
|
||||||
r.Header.Set("X-Initiator", initiator)
|
r.Header.Set("X-Initiator", initiator)
|
||||||
}
|
}
|
||||||
|
|
||||||
func detectLastConversationRole(body []byte) string {
|
// isAgentInitiated determines whether the current request is agent-initiated
|
||||||
|
// (tool callbacks, continuations) rather than user-initiated (new user prompt).
|
||||||
|
//
|
||||||
|
// GitHub Copilot uses the X-Initiator header for billing:
|
||||||
|
// - "user" → consumes premium request quota
|
||||||
|
// - "agent" → free (tool loops, continuations)
|
||||||
|
//
|
||||||
|
// The challenge: Claude Code sends tool results as role:"user" messages with
|
||||||
|
// content type "tool_result". After translation to OpenAI format, the tool_result
|
||||||
|
// part becomes a separate role:"tool" message, but if the original Claude message
|
||||||
|
// also contained text content (e.g. skill invocations, attachment descriptions),
|
||||||
|
// a role:"user" message is emitted AFTER the tool message, making the last message
|
||||||
|
// appear user-initiated when it's actually part of an agent tool loop.
|
||||||
|
//
|
||||||
|
// VSCode Copilot Chat solves this with explicit flags (iterationNumber,
|
||||||
|
// isContinuation, subAgentInvocationId). Since CPA doesn't have these flags,
|
||||||
|
// we infer agent status by checking whether the conversation contains prior
|
||||||
|
// assistant/tool messages — if it does, the current request is a continuation.
|
||||||
|
//
|
||||||
|
// References:
|
||||||
|
// - opencode#8030, opencode#15824: same root cause and fix approach
|
||||||
|
// - vscode-copilot-chat: toolCallingLoop.ts (iterationNumber === 0)
|
||||||
|
// - pi-ai: github-copilot-headers.ts (last message role check)
|
||||||
|
func isAgentInitiated(body []byte) bool {
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return ""
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Chat Completions API: check messages array
|
||||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||||
arr := messages.Array()
|
arr := messages.Array()
|
||||||
|
if len(arr) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRole := ""
|
||||||
for i := len(arr) - 1; i >= 0; i-- {
|
for i := len(arr) - 1; i >= 0; i-- {
|
||||||
if role := arr[i].Get("role").String(); role != "" {
|
if r := arr[i].Get("role").String(); r != "" {
|
||||||
return role
|
lastRole = r
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If last message is assistant or tool, clearly agent-initiated.
|
||||||
|
if lastRole == "assistant" || lastRole == "tool" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If last message is "user", check whether it contains tool results
|
||||||
|
// (indicating a tool-loop continuation) or if the preceding message
|
||||||
|
// is an assistant tool_use. This is more precise than checking for
|
||||||
|
// any prior assistant message, which would false-positive on genuine
|
||||||
|
// multi-turn follow-ups.
|
||||||
|
if lastRole == "user" {
|
||||||
|
// Check if the last user message contains tool_result content
|
||||||
|
lastContent := arr[len(arr)-1].Get("content")
|
||||||
|
if lastContent.Exists() && lastContent.IsArray() {
|
||||||
|
for _, part := range lastContent.Array() {
|
||||||
|
if part.Get("type").String() == "tool_result" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if the second-to-last message is an assistant with tool_use
|
||||||
|
if len(arr) >= 2 {
|
||||||
|
prev := arr[len(arr)-2]
|
||||||
|
if prev.Get("role").String() == "assistant" {
|
||||||
|
prevContent := prev.Get("content")
|
||||||
|
if prevContent.Exists() && prevContent.IsArray() {
|
||||||
|
for _, part := range prevContent.Array() {
|
||||||
|
if part.Get("type").String() == "tool_use" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Responses API: check input array
|
||||||
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
|
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
|
||||||
arr := inputs.Array()
|
arr := inputs.Array()
|
||||||
for i := len(arr) - 1; i >= 0; i-- {
|
if len(arr) == 0 {
|
||||||
item := arr[i]
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Most Responses input items carry a top-level role.
|
// Check last item
|
||||||
if role := item.Get("role").String(); role != "" {
|
last := arr[len(arr)-1]
|
||||||
return role
|
if role := last.Get("role").String(); role == "assistant" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch last.Get("type").String() {
|
||||||
|
case "function_call", "function_call_arguments", "computer_call":
|
||||||
|
return true
|
||||||
|
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If last item is user-role, check for prior non-user items
|
||||||
|
for _, item := range arr {
|
||||||
|
if role := item.Get("role").String(); role == "assistant" {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
switch item.Get("type").String() {
|
switch item.Get("type").String() {
|
||||||
case "function_call", "function_call_arguments", "computer_call":
|
case "function_call", "function_call_output", "function_call_response",
|
||||||
return "assistant"
|
"function_call_arguments", "computer_call", "computer_call_output":
|
||||||
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
|
return true
|
||||||
return "tool"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectVisionContent checks if the request body contains vision/image content.
|
// detectVisionContent checks if the request body contains vision/image content.
|
||||||
@@ -572,6 +692,85 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copilotUnsupportedBetas lists beta headers that are Anthropic-specific and
|
||||||
|
// must not be forwarded to GitHub Copilot. The context-1m beta enables 1M
|
||||||
|
// context on Anthropic's API, but Copilot's Claude models are limited to
|
||||||
|
// ~128K-200K. Passing it through would not enable 1M on Copilot, but stripping
|
||||||
|
// it from the translated body avoids confusing downstream translators.
|
||||||
|
var copilotUnsupportedBetas = []string{
|
||||||
|
"context-1m-2025-08-07",
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripUnsupportedBetas removes Anthropic-specific beta entries from the
|
||||||
|
// translated request body. In OpenAI format the betas may appear under
|
||||||
|
// "metadata.betas" or a top-level "betas" array; in Claude format they sit at
|
||||||
|
// "betas". This function checks all known locations.
|
||||||
|
func stripUnsupportedBetas(body []byte) []byte {
|
||||||
|
betaPaths := []string{"betas", "metadata.betas"}
|
||||||
|
for _, path := range betaPaths {
|
||||||
|
arr := gjson.GetBytes(body, path)
|
||||||
|
if !arr.Exists() || !arr.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var filtered []string
|
||||||
|
changed := false
|
||||||
|
for _, item := range arr.Array() {
|
||||||
|
beta := item.String()
|
||||||
|
if isCopilotUnsupportedBeta(beta) {
|
||||||
|
changed = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, beta)
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
body, _ = sjson.DeleteBytes(body, path)
|
||||||
|
} else {
|
||||||
|
body, _ = sjson.SetBytes(body, path, filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCopilotUnsupportedBeta(beta string) bool {
|
||||||
|
return slices.Contains(copilotUnsupportedBetas, beta)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeGitHubCopilotReasoningField maps Copilot's non-standard
|
||||||
|
// 'reasoning_text' field to the standard OpenAI 'reasoning_content' field
|
||||||
|
// that the SDK translator expects. This handles both streaming deltas
|
||||||
|
// (choices[].delta.reasoning_text) and non-streaming messages
|
||||||
|
// (choices[].message.reasoning_text). The field is only renamed when
|
||||||
|
// 'reasoning_content' is absent or null, preserving standard responses.
|
||||||
|
// All choices are processed to support n>1 requests.
|
||||||
|
func normalizeGitHubCopilotReasoningField(data []byte) []byte {
|
||||||
|
choices := gjson.GetBytes(data, "choices")
|
||||||
|
if !choices.Exists() || !choices.IsArray() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
for i := range choices.Array() {
|
||||||
|
// Non-streaming: choices[i].message.reasoning_text
|
||||||
|
msgRT := fmt.Sprintf("choices.%d.message.reasoning_text", i)
|
||||||
|
msgRC := fmt.Sprintf("choices.%d.message.reasoning_content", i)
|
||||||
|
if rt := gjson.GetBytes(data, msgRT); rt.Exists() && rt.String() != "" {
|
||||||
|
if rc := gjson.GetBytes(data, msgRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
|
||||||
|
data, _ = sjson.SetBytes(data, msgRC, rt.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Streaming: choices[i].delta.reasoning_text
|
||||||
|
deltaRT := fmt.Sprintf("choices.%d.delta.reasoning_text", i)
|
||||||
|
deltaRC := fmt.Sprintf("choices.%d.delta.reasoning_content", i)
|
||||||
|
if rt := gjson.GetBytes(data, deltaRT); rt.Exists() && rt.String() != "" {
|
||||||
|
if rc := gjson.GetBytes(data, deltaRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
|
||||||
|
data, _ = sjson.SetBytes(data, deltaRC, rt.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
|
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
|
||||||
if sourceFormat.String() == "openai-response" {
|
if sourceFormat.String() == "openai-response" {
|
||||||
return true
|
return true
|
||||||
@@ -596,12 +795,7 @@ func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func containsEndpoint(endpoints []string, endpoint string) bool {
|
func containsEndpoint(endpoints []string, endpoint string) bool {
|
||||||
for _, item := range endpoints {
|
return slices.Contains(endpoints, endpoint)
|
||||||
if item == endpoint {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// flattenAssistantContent converts assistant message content from array format
|
// flattenAssistantContent converts assistant message content from array format
|
||||||
@@ -856,6 +1050,32 @@ func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyGitHubCopilotResponsesDefaults sets required fields for the Responses API
|
||||||
|
// that both vscode-copilot-chat and pi-ai always include.
|
||||||
|
//
|
||||||
|
// References:
|
||||||
|
// - vscode-copilot-chat: src/platform/endpoint/node/responsesApi.ts
|
||||||
|
// - pi-ai (badlogic/pi-mono): packages/ai/src/providers/openai-responses.ts
|
||||||
|
func applyGitHubCopilotResponsesDefaults(body []byte) []byte {
|
||||||
|
// store: false — prevents request/response storage
|
||||||
|
if !gjson.GetBytes(body, "store").Exists() {
|
||||||
|
body, _ = sjson.SetBytes(body, "store", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// include: ["reasoning.encrypted_content"] — enables reasoning content
|
||||||
|
// reuse across turns, avoiding redundant computation
|
||||||
|
if !gjson.GetBytes(body, "include").Exists() {
|
||||||
|
body, _ = sjson.SetRawBytes(body, "include", []byte(`["reasoning.encrypted_content"]`))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If reasoning.effort is set but reasoning.summary is not, default to "auto"
|
||||||
|
if gjson.GetBytes(body, "reasoning.effort").Exists() && !gjson.GetBytes(body, "reasoning.summary").Exists() {
|
||||||
|
body, _ = sjson.SetBytes(body, "reasoning.summary", "auto")
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||||
tools := gjson.GetBytes(body, "tools")
|
tools := gjson.GetBytes(body, "tools")
|
||||||
if tools.Exists() {
|
if tools.Exists() {
|
||||||
@@ -1406,6 +1626,21 @@ func FetchGitHubCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg
|
|||||||
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
|
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Override with real limits from the Copilot API when available.
|
||||||
|
// The API returns per-account limits (individual vs business) under
|
||||||
|
// capabilities.limits, which are more accurate than our static
|
||||||
|
// fallback values. We use max_prompt_tokens as ContextLength because
|
||||||
|
// that's the hard limit the Copilot API enforces on prompt size —
|
||||||
|
// exceeding it triggers "prompt token count exceeds the limit" errors.
|
||||||
|
if limits := entry.Limits(); limits != nil {
|
||||||
|
if limits.MaxPromptTokens > 0 {
|
||||||
|
m.ContextLength = limits.MaxPromptTokens
|
||||||
|
}
|
||||||
|
if limits.MaxOutputTokens > 0 {
|
||||||
|
m.MaxCompletionTokens = limits.MaxOutputTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
models = append(models, m)
|
models = append(models, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
@@ -72,26 +75,39 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
||||||
t.Parallel()
|
// Not parallel: shares global model registry with DynamicRegistryWinsOverStatic.
|
||||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||||
t.Fatal("expected responses-only registry model to use /responses")
|
t.Fatal("expected responses-only registry model to use /responses")
|
||||||
}
|
}
|
||||||
|
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
|
||||||
|
t.Fatal("expected responses-only registry model to use /responses")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||||
t.Parallel()
|
// Not parallel: mutates global model registry, conflicts with RegistryResponsesOnlyModel.
|
||||||
|
|
||||||
reg := registry.GetGlobalRegistry()
|
reg := registry.GetGlobalRegistry()
|
||||||
clientID := "github-copilot-test-client"
|
clientID := "github-copilot-test-client"
|
||||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
|
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
|
||||||
ID: "gpt-5.4",
|
{
|
||||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
ID: "gpt-5.4",
|
||||||
}})
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4-mini",
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
|
},
|
||||||
|
})
|
||||||
defer reg.UnregisterClient(clientID)
|
defer reg.UnregisterClient(clientID)
|
||||||
|
|
||||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
|
||||||
|
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||||
@@ -238,14 +254,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||||
if gjson.Get(out, "type").String() != "message" {
|
if gjson.GetBytes(out, "type").String() != "message" {
|
||||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
t.Fatalf("type = %q, want message", gjson.GetBytes(out, "type").String())
|
||||||
}
|
}
|
||||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
if gjson.GetBytes(out, "content.0.type").String() != "text" {
|
||||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
t.Fatalf("content.0.type = %q, want text", gjson.GetBytes(out, "content.0.type").String())
|
||||||
}
|
}
|
||||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
if gjson.GetBytes(out, "content.0.text").String() != "hello" {
|
||||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
t.Fatalf("content.0.text = %q, want hello", gjson.GetBytes(out, "content.0.text").String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -253,14 +269,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *test
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
if gjson.GetBytes(out, "content.0.type").String() != "tool_use" {
|
||||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
t.Fatalf("content.0.type = %q, want tool_use", gjson.GetBytes(out, "content.0.type").String())
|
||||||
}
|
}
|
||||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
if gjson.GetBytes(out, "content.0.name").String() != "sum" {
|
||||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
t.Fatalf("content.0.name = %q, want sum", gjson.GetBytes(out, "content.0.name").String())
|
||||||
}
|
}
|
||||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
if gjson.GetBytes(out, "stop_reason").String() != "tool_use" {
|
||||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
t.Fatalf("stop_reason = %q, want tool_use", gjson.GetBytes(out, "stop_reason").String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,18 +285,24 @@ func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.
|
|||||||
var param any
|
var param any
|
||||||
|
|
||||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
if len(created) == 0 || !strings.Contains(string(created[0]), "message_start") {
|
||||||
t.Fatalf("created events = %#v, want message_start", created)
|
t.Fatalf("created events = %#v, want message_start", created)
|
||||||
}
|
}
|
||||||
|
|
||||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||||
joinedDelta := strings.Join(delta, "")
|
var joinedDelta string
|
||||||
|
for _, d := range delta {
|
||||||
|
joinedDelta += string(d)
|
||||||
|
}
|
||||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||||
}
|
}
|
||||||
|
|
||||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||||
joinedCompleted := strings.Join(completed, "")
|
var joinedCompleted string
|
||||||
|
for _, c := range completed {
|
||||||
|
joinedCompleted += string(c)
|
||||||
|
}
|
||||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||||
}
|
}
|
||||||
@@ -299,15 +321,17 @@ func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) {
|
func TestApplyHeaders_XInitiator_AgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
e := &GitHubCopilotExecutor{}
|
e := &GitHubCopilotExecutor{}
|
||||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
// Last role governs the initiator decision.
|
// When the last role is "user" and the message contains tool_result content,
|
||||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
|
// the request is a continuation (e.g. Claude tool result translated to a
|
||||||
|
// synthetic user message). Should be "agent".
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu1","content":"file contents..."}]}]}`)
|
||||||
e.applyHeaders(req, "token", body)
|
e.applyHeaders(req, "token", body)
|
||||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
t.Fatalf("X-Initiator = %q, want agent (last user contains tool_result)", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,10 +339,11 @@ func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
e := &GitHubCopilotExecutor{}
|
e := &GitHubCopilotExecutor{}
|
||||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
// When the last message has role "tool", it's clearly agent-initiated.
|
||||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||||
e.applyHeaders(req, "token", body)
|
e.applyHeaders(req, "token", body)
|
||||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
|
t.Fatalf("X-Initiator = %q, want agent (last role is tool)", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,14 +358,15 @@ func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) {
|
func TestApplyHeaders_XInitiator_InputArrayAgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
e := &GitHubCopilotExecutor{}
|
e := &GitHubCopilotExecutor{}
|
||||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
// Responses API: last item is user-role but history contains assistant → agent.
|
||||||
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
||||||
e.applyHeaders(req, "token", body)
|
e.applyHeaders(req, "token", body)
|
||||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||||
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
|
t.Fatalf("X-Initiator = %q, want agent (history has assistant)", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,6 +381,33 @@ func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_UserInMultiTurnNoTools(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
// Genuine multi-turn: user → assistant (plain text) → user follow-up.
|
||||||
|
// No tool messages → should be "user" (not a false-positive).
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"what is 2+2?"}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want user (genuine multi-turn, no tools)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_XInitiator_UserFollowUpAfterToolHistory(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
// User follow-up after a completed tool-use conversation.
|
||||||
|
// The last message is a genuine user question — should be "user", not "agent".
|
||||||
|
// This aligns with opencode's behavior: only active tool loops are agent-initiated.
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"tool","tool_call_id":"tu1","content":"file data"},{"role":"assistant","content":"I read the file."},{"role":"user","content":"What did we do so far?"}]}`)
|
||||||
|
e.applyHeaders(req, "token", body)
|
||||||
|
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||||
|
t.Fatalf("X-Initiator = %q, want user (genuine follow-up after tool history)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Tests for x-github-api-version header (Problem M) ---
|
// --- Tests for x-github-api-version header (Problem M) ---
|
||||||
|
|
||||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||||
@@ -401,3 +454,364 @@ func TestDetectVisionContent_NoMessages(t *testing.T) {
|
|||||||
t.Fatal("expected no vision content when messages field is absent")
|
t.Fatal("expected no vision content when messages field is absent")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Tests for applyGitHubCopilotResponsesDefaults ---
|
||||||
|
|
||||||
|
func TestApplyGitHubCopilotResponsesDefaults_SetsAllDefaults(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"input":"hello","reasoning":{"effort":"medium"}}`)
|
||||||
|
got := applyGitHubCopilotResponsesDefaults(body)
|
||||||
|
|
||||||
|
if gjson.GetBytes(got, "store").Bool() != false {
|
||||||
|
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
|
||||||
|
}
|
||||||
|
inc := gjson.GetBytes(got, "include")
|
||||||
|
if !inc.IsArray() || inc.Array()[0].String() != "reasoning.encrypted_content" {
|
||||||
|
t.Fatalf("include = %s, want [\"reasoning.encrypted_content\"]", inc.Raw)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(got, "reasoning.summary").String() != "auto" {
|
||||||
|
t.Fatalf("reasoning.summary = %q, want auto", gjson.GetBytes(got, "reasoning.summary").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyGitHubCopilotResponsesDefaults_DoesNotOverrideExisting(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"input":"hello","store":true,"include":["other"],"reasoning":{"effort":"high","summary":"concise"}}`)
|
||||||
|
got := applyGitHubCopilotResponsesDefaults(body)
|
||||||
|
|
||||||
|
if gjson.GetBytes(got, "store").Bool() != true {
|
||||||
|
t.Fatalf("store should not be overridden, got %s", gjson.GetBytes(got, "store").Raw)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(got, "include").Array()[0].String() != "other" {
|
||||||
|
t.Fatalf("include should not be overridden, got %s", gjson.GetBytes(got, "include").Raw)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(got, "reasoning.summary").String() != "concise" {
|
||||||
|
t.Fatalf("reasoning.summary should not be overridden, got %q", gjson.GetBytes(got, "reasoning.summary").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyGitHubCopilotResponsesDefaults_NoReasoningEffort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"input":"hello"}`)
|
||||||
|
got := applyGitHubCopilotResponsesDefaults(body)
|
||||||
|
|
||||||
|
if gjson.GetBytes(got, "store").Bool() != false {
|
||||||
|
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
|
||||||
|
}
|
||||||
|
// reasoning.summary should NOT be set when reasoning.effort is absent
|
||||||
|
if gjson.GetBytes(got, "reasoning.summary").Exists() {
|
||||||
|
t.Fatalf("reasoning.summary should not be set when reasoning.effort is absent, got %q", gjson.GetBytes(got, "reasoning.summary").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Tests for normalizeGitHubCopilotReasoningField ---
|
||||||
|
|
||||||
|
func TestNormalizeReasoningField_NonStreaming(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
data := []byte(`{"choices":[{"message":{"content":"hello","reasoning_text":"I think..."}}]}`)
|
||||||
|
got := normalizeGitHubCopilotReasoningField(data)
|
||||||
|
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||||
|
if rc != "I think..." {
|
||||||
|
t.Fatalf("reasoning_content = %q, want %q", rc, "I think...")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeReasoningField_Streaming(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
data := []byte(`{"choices":[{"delta":{"reasoning_text":"thinking delta"}}]}`)
|
||||||
|
got := normalizeGitHubCopilotReasoningField(data)
|
||||||
|
rc := gjson.GetBytes(got, "choices.0.delta.reasoning_content").String()
|
||||||
|
if rc != "thinking delta" {
|
||||||
|
t.Fatalf("reasoning_content = %q, want %q", rc, "thinking delta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeReasoningField_PreservesExistingReasoningContent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
data := []byte(`{"choices":[{"message":{"reasoning_text":"old","reasoning_content":"existing"}}]}`)
|
||||||
|
got := normalizeGitHubCopilotReasoningField(data)
|
||||||
|
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||||
|
if rc != "existing" {
|
||||||
|
t.Fatalf("reasoning_content = %q, want %q (should not overwrite)", rc, "existing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeReasoningField_MultiChoice(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
data := []byte(`{"choices":[{"message":{"reasoning_text":"thought-0"}},{"message":{"reasoning_text":"thought-1"}}]}`)
|
||||||
|
got := normalizeGitHubCopilotReasoningField(data)
|
||||||
|
rc0 := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||||
|
rc1 := gjson.GetBytes(got, "choices.1.message.reasoning_content").String()
|
||||||
|
if rc0 != "thought-0" {
|
||||||
|
t.Fatalf("choices[0].reasoning_content = %q, want %q", rc0, "thought-0")
|
||||||
|
}
|
||||||
|
if rc1 != "thought-1" {
|
||||||
|
t.Fatalf("choices[1].reasoning_content = %q, want %q", rc1, "thought-1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeReasoningField_NoChoices(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
data := []byte(`{"id":"chatcmpl-123"}`)
|
||||||
|
got := normalizeGitHubCopilotReasoningField(data)
|
||||||
|
if string(got) != string(data) {
|
||||||
|
t.Fatalf("expected no change, got %s", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyHeaders_OpenAIIntentValue(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||||
|
e.applyHeaders(req, "token", nil)
|
||||||
|
if got := req.Header.Get("Openai-Intent"); got != "conversation-edits" {
|
||||||
|
t.Fatalf("Openai-Intent = %q, want conversation-edits", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Tests for CountTokens (local tiktoken estimation) ---
|
||||||
|
|
||||||
|
func TestCountTokens_ReturnsPositiveCount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}`)
|
||||||
|
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Payload: body,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountTokens() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Payload) == 0 {
|
||||||
|
t.Fatal("CountTokens() returned empty payload")
|
||||||
|
}
|
||||||
|
// The response should contain a positive token count.
|
||||||
|
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||||
|
if tokens <= 0 {
|
||||||
|
t.Fatalf("expected positive token count, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCountTokens_ClaudeSourceFormatTranslates(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4","messages":[{"role":"user","content":"Tell me a joke"}],"max_tokens":1024}`)
|
||||||
|
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Payload: body,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountTokens() error: %v", err)
|
||||||
|
}
|
||||||
|
// Claude source format → should get input_tokens in response
|
||||||
|
inputTokens := gjson.GetBytes(resp.Payload, "input_tokens").Int()
|
||||||
|
if inputTokens <= 0 {
|
||||||
|
// Fallback: check usage.prompt_tokens (depends on translator registration)
|
||||||
|
promptTokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||||
|
if promptTokens <= 0 {
|
||||||
|
t.Fatalf("expected positive token count, got payload: %s", resp.Payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCountTokens_EmptyPayload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
e := &GitHubCopilotExecutor{}
|
||||||
|
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Payload: []byte(`{"model":"gpt-4o","messages":[]}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountTokens() error: %v", err)
|
||||||
|
}
|
||||||
|
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||||
|
// Empty messages should return 0 tokens.
|
||||||
|
if tokens != 0 {
|
||||||
|
t.Fatalf("expected 0 tokens for empty messages, got %d", tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripUnsupportedBetas_RemovesContext1M(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-opus-4.6","betas":["interleaved-thinking-2025-05-14","context-1m-2025-08-07","claude-code-20250219"],"messages":[]}`)
|
||||||
|
result := stripUnsupportedBetas(body)
|
||||||
|
|
||||||
|
betas := gjson.GetBytes(result, "betas")
|
||||||
|
if !betas.Exists() {
|
||||||
|
t.Fatal("betas field should still exist after stripping")
|
||||||
|
}
|
||||||
|
for _, item := range betas.Array() {
|
||||||
|
if item.String() == "context-1m-2025-08-07" {
|
||||||
|
t.Fatal("context-1m-2025-08-07 should have been stripped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Other betas should be preserved
|
||||||
|
found := false
|
||||||
|
for _, item := range betas.Array() {
|
||||||
|
if item.String() == "interleaved-thinking-2025-05-14" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("other betas should be preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripUnsupportedBetas_NoBetasField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-4o","messages":[]}`)
|
||||||
|
result := stripUnsupportedBetas(body)
|
||||||
|
|
||||||
|
// Should be unchanged
|
||||||
|
if string(result) != string(body) {
|
||||||
|
t.Fatalf("body should be unchanged when no betas field exists, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripUnsupportedBetas_MetadataBetas(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-opus-4.6","metadata":{"betas":["context-1m-2025-08-07","other-beta"]},"messages":[]}`)
|
||||||
|
result := stripUnsupportedBetas(body)
|
||||||
|
|
||||||
|
betas := gjson.GetBytes(result, "metadata.betas")
|
||||||
|
if !betas.Exists() {
|
||||||
|
t.Fatal("metadata.betas field should still exist after stripping")
|
||||||
|
}
|
||||||
|
for _, item := range betas.Array() {
|
||||||
|
if item.String() == "context-1m-2025-08-07" {
|
||||||
|
t.Fatal("context-1m-2025-08-07 should have been stripped from metadata.betas")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if betas.Array()[0].String() != "other-beta" {
|
||||||
|
t.Fatal("other betas in metadata.betas should be preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripUnsupportedBetas_AllBetasStripped(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-opus-4.6","betas":["context-1m-2025-08-07"],"messages":[]}`)
|
||||||
|
result := stripUnsupportedBetas(body)
|
||||||
|
|
||||||
|
betas := gjson.GetBytes(result, "betas")
|
||||||
|
if betas.Exists() {
|
||||||
|
t.Fatal("betas field should be deleted when all betas are stripped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotModelEntry_Limits(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
capabilities map[string]any
|
||||||
|
wantNil bool
|
||||||
|
wantPrompt int
|
||||||
|
wantOutput int
|
||||||
|
wantContext int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil capabilities",
|
||||||
|
capabilities: nil,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no limits key",
|
||||||
|
capabilities: map[string]any{"family": "claude-opus-4.6"},
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "limits is not a map",
|
||||||
|
capabilities: map[string]any{"limits": "invalid"},
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all zero values",
|
||||||
|
capabilities: map[string]any{
|
||||||
|
"limits": map[string]any{
|
||||||
|
"max_context_window_tokens": float64(0),
|
||||||
|
"max_prompt_tokens": float64(0),
|
||||||
|
"max_output_tokens": float64(0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "individual account limits (128K prompt)",
|
||||||
|
capabilities: map[string]any{
|
||||||
|
"limits": map[string]any{
|
||||||
|
"max_context_window_tokens": float64(144000),
|
||||||
|
"max_prompt_tokens": float64(128000),
|
||||||
|
"max_output_tokens": float64(64000),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
wantPrompt: 128000,
|
||||||
|
wantOutput: 64000,
|
||||||
|
wantContext: 144000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "business account limits (168K prompt)",
|
||||||
|
capabilities: map[string]any{
|
||||||
|
"limits": map[string]any{
|
||||||
|
"max_context_window_tokens": float64(200000),
|
||||||
|
"max_prompt_tokens": float64(168000),
|
||||||
|
"max_output_tokens": float64(32000),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
wantPrompt: 168000,
|
||||||
|
wantOutput: 32000,
|
||||||
|
wantContext: 200000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial limits (only prompt)",
|
||||||
|
capabilities: map[string]any{
|
||||||
|
"limits": map[string]any{
|
||||||
|
"max_prompt_tokens": float64(128000),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNil: false,
|
||||||
|
wantPrompt: 128000,
|
||||||
|
wantOutput: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
entry := copilotauth.CopilotModelEntry{
|
||||||
|
ID: "claude-opus-4.6",
|
||||||
|
Capabilities: tt.capabilities,
|
||||||
|
}
|
||||||
|
limits := entry.Limits()
|
||||||
|
if tt.wantNil {
|
||||||
|
if limits != nil {
|
||||||
|
t.Fatalf("expected nil limits, got %+v", limits)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if limits == nil {
|
||||||
|
t.Fatal("expected non-nil limits, got nil")
|
||||||
|
}
|
||||||
|
if limits.MaxPromptTokens != tt.wantPrompt {
|
||||||
|
t.Errorf("MaxPromptTokens = %d, want %d", limits.MaxPromptTokens, tt.wantPrompt)
|
||||||
|
}
|
||||||
|
if limits.MaxOutputTokens != tt.wantOutput {
|
||||||
|
t.Errorf("MaxOutputTokens = %d, want %d", limits.MaxOutputTokens, tt.wantOutput)
|
||||||
|
}
|
||||||
|
if tt.wantContext > 0 && limits.MaxContextWindowTokens != tt.wantContext {
|
||||||
|
t.Errorf("MaxContextWindowTokens = %d, want %d", limits.MaxContextWindowTokens, tt.wantContext)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,12 +30,20 @@ const (
|
|||||||
gitLabChatEndpoint = "/api/v4/chat/completions"
|
gitLabChatEndpoint = "/api/v4/chat/completions"
|
||||||
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
|
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
|
||||||
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
|
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
|
||||||
|
gitLabContext1MBeta = "context-1m-2025-08-07"
|
||||||
|
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GitLabExecutor struct {
|
type GitLabExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type gitLabCatalogModel struct {
|
||||||
|
ID string
|
||||||
|
DisplayName string
|
||||||
|
Provider string
|
||||||
|
}
|
||||||
|
|
||||||
type gitLabPrompt struct {
|
type gitLabPrompt struct {
|
||||||
Instruction string
|
Instruction string
|
||||||
FileName string
|
FileName string
|
||||||
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
|
|||||||
Finished bool
|
Finished bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var gitLabAgenticCatalog = []gitLabCatalogModel{
|
||||||
|
{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
|
||||||
|
{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
|
||||||
|
{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
|
||||||
|
{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
|
||||||
|
{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
|
||||||
|
{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
|
||||||
|
{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
|
||||||
|
{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
|
||||||
|
{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
|
||||||
|
{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var gitLabModelAliases = map[string]string{
|
||||||
|
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
|
||||||
|
}
|
||||||
|
|
||||||
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
|
||||||
return &GitLabExecutor{cfg: cfg}
|
return &GitLabExecutor{cfg: cfg}
|
||||||
}
|
}
|
||||||
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
|
|||||||
auth *cliproxyauth.Auth,
|
auth *cliproxyauth.Auth,
|
||||||
req cliproxyexecutor.Request,
|
req cliproxyexecutor.Request,
|
||||||
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
|
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
|
||||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
|
||||||
nativeReq := req
|
nativeReq := req
|
||||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||||
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
|
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||||
}
|
}
|
||||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
|
||||||
nativeReq := req
|
nativeReq := req
|
||||||
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
|
||||||
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
|
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
|
||||||
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
|
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
|
||||||
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
|
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
|
||||||
return NewClaudeExecutor(e.cfg), nativeAuth
|
return NewClaudeExecutor(e.cfg), nativeAuth
|
||||||
}
|
}
|
||||||
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
|
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
|
||||||
return NewCodexExecutor(e.cfg), nativeAuth
|
return NewCodexExecutor(e.cfg), nativeAuth
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
|||||||
if auth != nil {
|
if auth != nil {
|
||||||
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
|
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
|
||||||
}
|
}
|
||||||
for key, value := range gitLabGatewayHeaders(auth) {
|
for key, value := range gitLabGatewayHeaders(auth, "") {
|
||||||
if key == "" || value == "" {
|
if key == "" || value == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string {
|
func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
|
||||||
if auth == nil || auth.Metadata == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
raw, ok := auth.Metadata["duo_gateway_headers"]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make(map[string]string)
|
out := make(map[string]string)
|
||||||
switch typed := raw.(type) {
|
if auth != nil && auth.Metadata != nil {
|
||||||
case map[string]string:
|
raw, ok := auth.Metadata["duo_gateway_headers"]
|
||||||
for key, value := range typed {
|
if ok {
|
||||||
key = strings.TrimSpace(key)
|
switch typed := raw.(type) {
|
||||||
value = strings.TrimSpace(value)
|
case map[string]string:
|
||||||
if key != "" && value != "" {
|
for key, value := range typed {
|
||||||
out[key] = value
|
key = strings.TrimSpace(key)
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if key != "" && value != "" {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
for key, value := range typed {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
strValue := strings.TrimSpace(fmt.Sprint(value))
|
||||||
|
if strValue != "" {
|
||||||
|
out[key] = strValue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case map[string]any:
|
}
|
||||||
for key, value := range typed {
|
if _, ok := out["User-Agent"]; !ok {
|
||||||
key = strings.TrimSpace(key)
|
out["User-Agent"] = gitLabNativeUserAgent
|
||||||
if key == "" {
|
}
|
||||||
continue
|
if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
|
||||||
}
|
if _, ok := out["anthropic-beta"]; !ok {
|
||||||
strValue := strings.TrimSpace(fmt.Sprint(value))
|
out["anthropic-beta"] = gitLabContext1MBeta
|
||||||
if strValue != "" {
|
|
||||||
out[key] = strValue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(out) == 0 {
|
if len(out) == 0 {
|
||||||
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
|
|||||||
return promptTokens, int64(completionCount)
|
return promptTokens, int64(completionCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||||
if !gitLabUsesAnthropicGateway(auth) {
|
if !gitLabUsesAnthropicGateway(auth, requestedModel) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
baseURL := gitLabAnthropicGatewayBaseURL(auth)
|
baseURL := gitLabAnthropicGatewayBaseURL(auth)
|
||||||
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
nativeAuth.Attributes["api_key"] = token
|
nativeAuth.Attributes["api_key"] = token
|
||||||
nativeAuth.Attributes["base_url"] = baseURL
|
nativeAuth.Attributes["base_url"] = baseURL
|
||||||
for key, value := range gitLabGatewayHeaders(auth) {
|
nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
|
||||||
|
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
|
||||||
if key == "" || value == "" {
|
if key == "" || value == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
|
|||||||
return nativeAuth, true
|
return nativeAuth, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
|
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
|
||||||
if !gitLabUsesOpenAIGateway(auth) {
|
if !gitLabUsesOpenAIGateway(auth, requestedModel) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
baseURL := gitLabOpenAIGatewayBaseURL(auth)
|
baseURL := gitLabOpenAIGatewayBaseURL(auth)
|
||||||
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
|||||||
}
|
}
|
||||||
nativeAuth.Attributes["api_key"] = token
|
nativeAuth.Attributes["api_key"] = token
|
||||||
nativeAuth.Attributes["base_url"] = baseURL
|
nativeAuth.Attributes["base_url"] = baseURL
|
||||||
for key, value := range gitLabGatewayHeaders(auth) {
|
for key, value := range gitLabGatewayHeaders(auth, "openai") {
|
||||||
if key == "" || value == "" {
|
if key == "" || value == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
|
|||||||
return nativeAuth, true
|
return nativeAuth, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool {
|
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||||
if auth == nil || auth.Metadata == nil {
|
if auth == nil || auth.Metadata == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||||
if provider == "" {
|
|
||||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
|
||||||
provider = inferGitLabProviderFromModel(modelName)
|
|
||||||
}
|
|
||||||
return provider == "anthropic" &&
|
return provider == "anthropic" &&
|
||||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool {
|
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
|
||||||
if auth == nil || auth.Metadata == nil {
|
if auth == nil || auth.Metadata == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
provider := gitLabGatewayProvider(auth, requestedModel)
|
||||||
if provider == "" {
|
|
||||||
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
|
|
||||||
provider = inferGitLabProviderFromModel(modelName)
|
|
||||||
}
|
|
||||||
return provider == "openai" &&
|
return provider == "openai" &&
|
||||||
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
|
||||||
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
|
||||||
|
modelName := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
|
||||||
|
if provider := inferGitLabProviderFromModel(modelName); provider != "" {
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
|
||||||
|
if provider == "" {
|
||||||
|
provider = inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
|
||||||
|
}
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
|
||||||
func inferGitLabProviderFromModel(model string) string {
|
func inferGitLabProviderFromModel(model string) string {
|
||||||
model = strings.ToLower(strings.TrimSpace(model))
|
model = strings.ToLower(strings.TrimSpace(model))
|
||||||
switch {
|
switch {
|
||||||
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
|
|||||||
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
|
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
|
||||||
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
|
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
|
||||||
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
|
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
|
||||||
|
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
return requested
|
return requested
|
||||||
}
|
}
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
||||||
models := make([]*registry.ModelInfo, 0, 4)
|
models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
|
||||||
seen := make(map[string]struct{}, 4)
|
seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
|
||||||
addModel := func(id, displayName, provider string) {
|
addModel := func(id, displayName, provider string) {
|
||||||
id = strings.TrimSpace(id)
|
id = strings.TrimSpace(id)
|
||||||
if id == "" {
|
if id == "" {
|
||||||
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
addModel("gitlab-duo", "GitLab Duo", "gitlab")
|
addModel("gitlab-duo", "GitLab Duo", "gitlab")
|
||||||
|
for _, model := range gitLabAgenticCatalog {
|
||||||
|
addModel(model.ID, model.DisplayName, model.Provider)
|
||||||
|
}
|
||||||
|
for alias, upstream := range gitLabModelAliases {
|
||||||
|
target := strings.TrimSpace(upstream)
|
||||||
|
displayName := "GitLab Duo Alias"
|
||||||
|
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
|
||||||
|
if provider != "" {
|
||||||
|
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
|
||||||
|
}
|
||||||
|
addModel(alias, displayName, provider)
|
||||||
|
}
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
|
||||||
|
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
|
||||||
|
var gotPath string
|
||||||
|
var gotModel string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
gotAuthHeader = r.Header.Get("Authorization")
|
||||||
|
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||||
|
gotBetaHeader = r.Header.Get("anthropic-beta")
|
||||||
|
gotUserAgent = r.Header.Get("User-Agent")
|
||||||
|
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
exec := NewGitLabExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "gitlab",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"duo_gateway_base_url": srv.URL,
|
||||||
|
"duo_gateway_token": "gateway-token",
|
||||||
|
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req := cliproxyexecutor.Request{
|
||||||
|
Model: "duo-chat-gpt-5-codex",
|
||||||
|
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||||
|
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||||
|
}
|
||||||
|
if gotAuthHeader != "Bearer gateway-token" {
|
||||||
|
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||||
|
}
|
||||||
|
if gotRealmHeader != "saas" {
|
||||||
|
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||||
|
}
|
||||||
|
if gotBetaHeader != gitLabContext1MBeta {
|
||||||
|
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
|
||||||
|
}
|
||||||
|
if gotUserAgent != gitLabNativeUserAgent {
|
||||||
|
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||||
|
}
|
||||||
|
if gotModel != "duo-chat-gpt-5-codex" {
|
||||||
|
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
|
||||||
|
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
|||||||
ID: "gitlab-auth.json",
|
ID: "gitlab-auth.json",
|
||||||
Provider: "gitlab",
|
Provider: "gitlab",
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"base_url": srv.URL,
|
"base_url": srv.URL,
|
||||||
"access_token": "oauth-access",
|
"access_token": "oauth-access",
|
||||||
"refresh_token": "oauth-refresh",
|
"refresh_token": "oauth-refresh",
|
||||||
"oauth_client_id": "client-id",
|
"oauth_client_id": "client-id",
|
||||||
"oauth_client_secret": "client-secret",
|
"auth_method": "oauth",
|
||||||
"auth_method": "oauth",
|
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||||
var gotPath string
|
var gotPath, gotBetaHeader, gotUserAgent string
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotPath = r.URL.Path
|
gotPath = r.URL.Path
|
||||||
|
gotBetaHeader = r.Header.Get("Anthropic-Beta")
|
||||||
|
gotUserAgent = r.Header.Get("User-Agent")
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
_, _ = w.Write([]byte("event: message_start\n"))
|
_, _ = w.Write([]byte("event: message_start\n"))
|
||||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||||
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
|||||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
|
||||||
|
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
|
||||||
|
}
|
||||||
|
if gotUserAgent != gitLabNativeUserAgent {
|
||||||
|
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||||
|
}
|
||||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type codexCache struct {
|
type CodexCache struct {
|
||||||
ID string
|
ID string
|
||||||
Expire time.Time
|
Expire time.Time
|
||||||
}
|
}
|
||||||
@@ -13,7 +13,7 @@ type codexCache struct {
|
|||||||
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
|
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
|
||||||
// Protected by codexCacheMu. Entries expire after 1 hour.
|
// Protected by codexCacheMu. Entries expire after 1 hour.
|
||||||
var (
|
var (
|
||||||
codexCacheMap = make(map[string]codexCache)
|
codexCacheMap = make(map[string]CodexCache)
|
||||||
codexCacheMu sync.RWMutex
|
codexCacheMu sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,20 +50,20 @@ func purgeExpiredCodexCache() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
// GetCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
||||||
func getCodexCache(key string) (codexCache, bool) {
|
func GetCodexCache(key string) (CodexCache, bool) {
|
||||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||||
codexCacheMu.RLock()
|
codexCacheMu.RLock()
|
||||||
cache, ok := codexCacheMap[key]
|
cache, ok := codexCacheMap[key]
|
||||||
codexCacheMu.RUnlock()
|
codexCacheMu.RUnlock()
|
||||||
if !ok || cache.Expire.Before(time.Now()) {
|
if !ok || cache.Expire.Before(time.Now()) {
|
||||||
return codexCache{}, false
|
return CodexCache{}, false
|
||||||
}
|
}
|
||||||
return cache, true
|
return cache, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// setCodexCache stores a cache entry.
|
// SetCodexCache stores a cache entry.
|
||||||
func setCodexCache(key string, cache codexCache) {
|
func SetCodexCache(key string, cache CodexCache) {
|
||||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||||
codexCacheMu.Lock()
|
codexCacheMu.Lock()
|
||||||
codexCacheMap[key] = cache
|
codexCacheMap[key] = cache
|
||||||
38
internal/runtime/executor/helps/claude_builtin_tools.go
Normal file
38
internal/runtime/executor/helps/claude_builtin_tools.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package helps
|
||||||
|
|
||||||
|
import "github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
var defaultClaudeBuiltinToolNames = []string{
|
||||||
|
"web_search",
|
||||||
|
"code_execution",
|
||||||
|
"text_editor",
|
||||||
|
"computer",
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClaudeBuiltinToolRegistry() map[string]bool {
|
||||||
|
registry := make(map[string]bool, len(defaultClaudeBuiltinToolNames))
|
||||||
|
for _, name := range defaultClaudeBuiltinToolNames {
|
||||||
|
registry[name] = true
|
||||||
|
}
|
||||||
|
return registry
|
||||||
|
}
|
||||||
|
|
||||||
|
func AugmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool {
|
||||||
|
if registry == nil {
|
||||||
|
registry = newClaudeBuiltinToolRegistry()
|
||||||
|
}
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return registry
|
||||||
|
}
|
||||||
|
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||||
|
if tool.Get("type").String() == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if name := tool.Get("name").String(); name != "" {
|
||||||
|
registry[name] = true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return registry
|
||||||
|
}
|
||||||
32
internal/runtime/executor/helps/claude_builtin_tools_test.go
Normal file
32
internal/runtime/executor/helps/claude_builtin_tools_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package helps
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) {
|
||||||
|
registry := AugmentClaudeBuiltinToolRegistry(nil, nil)
|
||||||
|
for _, name := range defaultClaudeBuiltinToolNames {
|
||||||
|
if !registry[name] {
|
||||||
|
t.Fatalf("default builtin %q missing from fallback registry", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody(t *testing.T) {
|
||||||
|
registry := AugmentClaudeBuiltinToolRegistry([]byte(`{
|
||||||
|
"tools": [
|
||||||
|
{"type": "web_search_20250305", "name": "web_search"},
|
||||||
|
{"type": "custom_builtin_20250401", "name": "special_builtin"},
|
||||||
|
{"name": "Read"}
|
||||||
|
]
|
||||||
|
}`), nil)
|
||||||
|
|
||||||
|
if !registry["web_search"] {
|
||||||
|
t.Fatal("expected default typed builtin web_search in registry")
|
||||||
|
}
|
||||||
|
if !registry["special_builtin"] {
|
||||||
|
t.Fatal("expected typed builtin from body to be added to registry")
|
||||||
|
}
|
||||||
|
if registry["Read"] {
|
||||||
|
t.Fatal("expected untyped custom tool to stay out of builtin registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
@@ -32,7 +32,7 @@ var (
|
|||||||
claudeDeviceProfileCacheMu sync.RWMutex
|
claudeDeviceProfileCacheMu sync.RWMutex
|
||||||
claudeDeviceProfileCacheCleanupOnce sync.Once
|
claudeDeviceProfileCacheCleanupOnce sync.Once
|
||||||
|
|
||||||
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
|
ClaudeDeviceProfileBeforeCandidateStore func(ClaudeDeviceProfile)
|
||||||
)
|
)
|
||||||
|
|
||||||
type claudeCLIVersion struct {
|
type claudeCLIVersion struct {
|
||||||
@@ -63,29 +63,43 @@ func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeDeviceProfile struct {
|
type ClaudeDeviceProfile struct {
|
||||||
UserAgent string
|
UserAgent string
|
||||||
PackageVersion string
|
PackageVersion string
|
||||||
RuntimeVersion string
|
RuntimeVersion string
|
||||||
OS string
|
OS string
|
||||||
Arch string
|
Arch string
|
||||||
Version claudeCLIVersion
|
version claudeCLIVersion
|
||||||
HasVersion bool
|
hasVersion bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeDeviceProfileCacheEntry struct {
|
type claudeDeviceProfileCacheEntry struct {
|
||||||
profile claudeDeviceProfile
|
profile ClaudeDeviceProfile
|
||||||
expire time.Time
|
expire time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
func ClaudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
|
||||||
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
func ResetClaudeDeviceProfileCache() {
|
||||||
|
claudeDeviceProfileCacheMu.Lock()
|
||||||
|
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
|
||||||
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapStainlessOS() string {
|
||||||
|
return mapStainlessOS()
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapStainlessArch() string {
|
||||||
|
return mapStainlessArch()
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultClaudeDeviceProfile(cfg *config.Config) ClaudeDeviceProfile {
|
||||||
hdrDefault := func(cfgVal, fallback string) string {
|
hdrDefault := func(cfgVal, fallback string) string {
|
||||||
if strings.TrimSpace(cfgVal) != "" {
|
if strings.TrimSpace(cfgVal) != "" {
|
||||||
return strings.TrimSpace(cfgVal)
|
return strings.TrimSpace(cfgVal)
|
||||||
@@ -98,7 +112,7 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
|||||||
hd = cfg.ClaudeHeaderDefaults
|
hd = cfg.ClaudeHeaderDefaults
|
||||||
}
|
}
|
||||||
|
|
||||||
profile := claudeDeviceProfile{
|
profile := ClaudeDeviceProfile{
|
||||||
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
|
||||||
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
|
||||||
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
|
||||||
@@ -106,8 +120,8 @@ func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
|
|||||||
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
|
||||||
}
|
}
|
||||||
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||||
profile.Version = version
|
profile.version = version
|
||||||
profile.HasVersion = true
|
profile.hasVersion = true
|
||||||
}
|
}
|
||||||
return profile
|
return profile
|
||||||
}
|
}
|
||||||
@@ -162,17 +176,17 @@ func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
|
|||||||
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
|
func shouldUpgradeClaudeDeviceProfile(candidate, current ClaudeDeviceProfile) bool {
|
||||||
if candidate.UserAgent == "" || !candidate.HasVersion {
|
if candidate.UserAgent == "" || !candidate.hasVersion {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if current.UserAgent == "" || !current.HasVersion {
|
if current.UserAgent == "" || !current.hasVersion {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return candidate.Version.Compare(current.Version) > 0
|
return candidate.version.Compare(current.version) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
func pinClaudeDeviceProfilePlatform(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
|
||||||
profile.OS = baseline.OS
|
profile.OS = baseline.OS
|
||||||
profile.Arch = baseline.Arch
|
profile.Arch = baseline.Arch
|
||||||
return profile
|
return profile
|
||||||
@@ -180,38 +194,38 @@ func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claud
|
|||||||
|
|
||||||
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
|
||||||
// baseline platform and enforces the baseline software fingerprint as a floor.
|
// baseline platform and enforces the baseline software fingerprint as a floor.
|
||||||
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
|
func normalizeClaudeDeviceProfile(profile, baseline ClaudeDeviceProfile) ClaudeDeviceProfile {
|
||||||
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
|
||||||
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
if profile.UserAgent == "" || !profile.hasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
|
||||||
profile.UserAgent = baseline.UserAgent
|
profile.UserAgent = baseline.UserAgent
|
||||||
profile.PackageVersion = baseline.PackageVersion
|
profile.PackageVersion = baseline.PackageVersion
|
||||||
profile.RuntimeVersion = baseline.RuntimeVersion
|
profile.RuntimeVersion = baseline.RuntimeVersion
|
||||||
profile.Version = baseline.Version
|
profile.version = baseline.version
|
||||||
profile.HasVersion = baseline.HasVersion
|
profile.hasVersion = baseline.hasVersion
|
||||||
}
|
}
|
||||||
return profile
|
return profile
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
|
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (ClaudeDeviceProfile, bool) {
|
||||||
if headers == nil {
|
if headers == nil {
|
||||||
return claudeDeviceProfile{}, false
|
return ClaudeDeviceProfile{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
||||||
version, ok := parseClaudeCLIVersion(userAgent)
|
version, ok := parseClaudeCLIVersion(userAgent)
|
||||||
if !ok {
|
if !ok {
|
||||||
return claudeDeviceProfile{}, false
|
return ClaudeDeviceProfile{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
baseline := defaultClaudeDeviceProfile(cfg)
|
baseline := defaultClaudeDeviceProfile(cfg)
|
||||||
profile := claudeDeviceProfile{
|
profile := ClaudeDeviceProfile{
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
|
||||||
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
|
||||||
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
|
||||||
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
|
||||||
Version: version,
|
version: version,
|
||||||
HasVersion: true,
|
hasVersion: true,
|
||||||
}
|
}
|
||||||
return profile, true
|
return profile, true
|
||||||
}
|
}
|
||||||
@@ -263,7 +277,7 @@ func purgeExpiredClaudeDeviceProfiles() {
|
|||||||
claudeDeviceProfileCacheMu.Unlock()
|
claudeDeviceProfileCacheMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
|
func ResolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) ClaudeDeviceProfile {
|
||||||
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
|
||||||
|
|
||||||
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
|
||||||
@@ -283,8 +297,8 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
|
|||||||
claudeDeviceProfileCacheMu.RUnlock()
|
claudeDeviceProfileCacheMu.RUnlock()
|
||||||
|
|
||||||
if hasCandidate {
|
if hasCandidate {
|
||||||
if claudeDeviceProfileBeforeCandidateStore != nil {
|
if ClaudeDeviceProfileBeforeCandidateStore != nil {
|
||||||
claudeDeviceProfileBeforeCandidateStore(candidate)
|
ClaudeDeviceProfileBeforeCandidateStore(candidate)
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeDeviceProfileCacheMu.Lock()
|
claudeDeviceProfileCacheMu.Lock()
|
||||||
@@ -324,7 +338,7 @@ func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers
|
|||||||
return baseline
|
return baseline
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
|
func ApplyClaudeDeviceProfileHeaders(r *http.Request, profile ClaudeDeviceProfile) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -344,7 +358,17 @@ func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfil
|
|||||||
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
r.Header.Set("X-Stainless-Arch", profile.Arch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
// DefaultClaudeVersion returns the version string (e.g. "2.1.63") from the
|
||||||
|
// current baseline device profile. It extracts the version from the User-Agent.
|
||||||
|
func DefaultClaudeVersion(cfg *config.Config) string {
|
||||||
|
profile := defaultClaudeDeviceProfile(cfg)
|
||||||
|
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
|
||||||
|
return strconv.Itoa(version.major) + "." + strconv.Itoa(version.minor) + "." + strconv.Itoa(version.patch)
|
||||||
|
}
|
||||||
|
return "2.1.63"
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -18,9 +18,9 @@ type SensitiveWordMatcher struct {
|
|||||||
regex *regexp.Regexp
|
regex *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildSensitiveWordMatcher compiles a regex from the word list.
|
// BuildSensitiveWordMatcher compiles a regex from the word list.
|
||||||
// Words are sorted by length (longest first) for proper matching.
|
// Words are sorted by length (longest first) for proper matching.
|
||||||
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
func BuildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
||||||
if len(words) == 0 {
|
if len(words) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -81,9 +81,9 @@ func (m *SensitiveWordMatcher) obfuscateText(text string) string {
|
|||||||
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
|
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
|
||||||
}
|
}
|
||||||
|
|
||||||
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
// ObfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||||
// in system blocks and message content.
|
// in system blocks and message content.
|
||||||
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
func ObfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
if matcher == nil || matcher.regex == nil {
|
if matcher == nil || matcher.regex == nil {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
@@ -28,9 +28,17 @@ func isValidUserID(userID string) bool {
|
|||||||
return userIDPattern.MatchString(userID)
|
return userIDPattern.MatchString(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
|
func GenerateFakeUserID() string {
|
||||||
|
return generateFakeUserID()
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsValidUserID(userID string) bool {
|
||||||
|
return isValidUserID(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldCloak determines if request should be cloaked based on config and client User-Agent.
|
||||||
// Returns true if cloaking should be applied.
|
// Returns true if cloaking should be applied.
|
||||||
func shouldCloak(cloakMode string, userAgent string) bool {
|
func ShouldCloak(cloakMode string, userAgent string) bool {
|
||||||
switch strings.ToLower(cloakMode) {
|
switch strings.ToLower(cloakMode) {
|
||||||
case "always":
|
case "always":
|
||||||
return true
|
return true
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"html"
|
"html"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -19,13 +20,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
|
||||||
apiRequestKey = "API_REQUEST"
|
apiRequestKey = "API_REQUEST"
|
||||||
apiResponseKey = "API_RESPONSE"
|
apiResponseKey = "API_RESPONSE"
|
||||||
|
apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE"
|
||||||
)
|
)
|
||||||
|
|
||||||
// upstreamRequestLog captures the outbound upstream request details for logging.
|
// UpstreamRequestLog captures the outbound upstream request details for logging.
|
||||||
type upstreamRequestLog struct {
|
type UpstreamRequestLog struct {
|
||||||
URL string
|
URL string
|
||||||
Method string
|
Method string
|
||||||
Headers http.Header
|
Headers http.Header
|
||||||
@@ -46,11 +48,12 @@ type upstreamAttempt struct {
|
|||||||
headersWritten bool
|
headersWritten bool
|
||||||
bodyStarted bool
|
bodyStarted bool
|
||||||
bodyHasContent bool
|
bodyHasContent bool
|
||||||
|
prevWasSSEEvent bool
|
||||||
errorWritten bool
|
errorWritten bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// recordAPIRequest stores the upstream request metadata in Gin context for request logging.
|
// RecordAPIRequest stores the upstream request metadata in Gin context for request logging.
|
||||||
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
|
func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
|
||||||
if cfg == nil || !cfg.RequestLog {
|
if cfg == nil || !cfg.RequestLog {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -96,8 +99,8 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
|
|||||||
updateAggregatedRequest(ginCtx, attempts)
|
updateAggregatedRequest(ginCtx, attempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// recordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
|
// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
|
||||||
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||||
if cfg == nil || !cfg.RequestLog {
|
if cfg == nil || !cfg.RequestLog {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -122,8 +125,8 @@ func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status i
|
|||||||
updateAggregatedResponse(ginCtx, attempts)
|
updateAggregatedResponse(ginCtx, attempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// recordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
|
// RecordAPIResponseError adds an error entry for the latest attempt when no HTTP response is available.
|
||||||
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
func RecordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||||
if cfg == nil || !cfg.RequestLog || err == nil {
|
if cfg == nil || !cfg.RequestLog || err == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -147,8 +150,8 @@ func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error)
|
|||||||
updateAggregatedResponse(ginCtx, attempts)
|
updateAggregatedResponse(ginCtx, attempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// appendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
|
// AppendAPIResponseChunk appends an upstream response chunk to Gin context for request logging.
|
||||||
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
func AppendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||||
if cfg == nil || !cfg.RequestLog {
|
if cfg == nil || !cfg.RequestLog {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -173,15 +176,157 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
|||||||
attempt.response.WriteString("Body:\n")
|
attempt.response.WriteString("Body:\n")
|
||||||
attempt.bodyStarted = true
|
attempt.bodyStarted = true
|
||||||
}
|
}
|
||||||
|
currentChunkIsSSEEvent := bytes.HasPrefix(data, []byte("event:"))
|
||||||
|
currentChunkIsSSEData := bytes.HasPrefix(data, []byte("data:"))
|
||||||
if attempt.bodyHasContent {
|
if attempt.bodyHasContent {
|
||||||
attempt.response.WriteString("\n\n")
|
separator := "\n\n"
|
||||||
|
if attempt.prevWasSSEEvent && currentChunkIsSSEData {
|
||||||
|
separator = "\n"
|
||||||
|
}
|
||||||
|
attempt.response.WriteString(separator)
|
||||||
}
|
}
|
||||||
attempt.response.WriteString(string(data))
|
attempt.response.WriteString(string(data))
|
||||||
attempt.bodyHasContent = true
|
attempt.bodyHasContent = true
|
||||||
|
attempt.prevWasSSEEvent = currentChunkIsSSEEvent
|
||||||
|
|
||||||
updateAggregatedResponse(ginCtx, attempts)
|
updateAggregatedResponse(ginCtx, attempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordAPIWebsocketRequest stores an upstream websocket request event in Gin context.
|
||||||
|
func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info UpstreamRequestLog) {
|
||||||
|
if cfg == nil || !cfg.RequestLog {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ginCtx := ginContextFrom(ctx)
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := &strings.Builder{}
|
||||||
|
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
|
builder.WriteString("Event: api.websocket.request\n")
|
||||||
|
if info.URL != "" {
|
||||||
|
builder.WriteString(fmt.Sprintf("Upstream URL: %s\n", info.URL))
|
||||||
|
}
|
||||||
|
if auth := formatAuthInfo(info); auth != "" {
|
||||||
|
builder.WriteString(fmt.Sprintf("Auth: %s\n", auth))
|
||||||
|
}
|
||||||
|
builder.WriteString("Headers:\n")
|
||||||
|
writeHeaders(builder, info.Headers)
|
||||||
|
builder.WriteString("\nBody:\n")
|
||||||
|
if len(info.Body) > 0 {
|
||||||
|
builder.Write(info.Body)
|
||||||
|
} else {
|
||||||
|
builder.WriteString("<empty>")
|
||||||
|
}
|
||||||
|
builder.WriteString("\n")
|
||||||
|
|
||||||
|
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
|
||||||
|
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||||
|
if cfg == nil || !cfg.RequestLog {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ginCtx := ginContextFrom(ctx)
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := &strings.Builder{}
|
||||||
|
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
|
builder.WriteString("Event: api.websocket.handshake\n")
|
||||||
|
if status > 0 {
|
||||||
|
builder.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||||
|
}
|
||||||
|
builder.WriteString("Headers:\n")
|
||||||
|
writeHeaders(builder, headers)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
|
||||||
|
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
|
||||||
|
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
|
||||||
|
if cfg == nil || !cfg.RequestLog {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ginCtx := ginContextFrom(ctx)
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
RecordAPIRequest(ctx, cfg, info)
|
||||||
|
RecordAPIResponseMetadata(ctx, cfg, status, headers)
|
||||||
|
AppendAPIResponseChunk(ctx, cfg, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebsocketUpgradeRequestURL converts a websocket URL back to its HTTP handshake URL for logging.
|
||||||
|
func WebsocketUpgradeRequestURL(rawURL string) string {
|
||||||
|
trimmedURL := strings.TrimSpace(rawURL)
|
||||||
|
if trimmedURL == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(trimmedURL)
|
||||||
|
if err != nil {
|
||||||
|
return trimmedURL
|
||||||
|
}
|
||||||
|
switch strings.ToLower(parsed.Scheme) {
|
||||||
|
case "ws":
|
||||||
|
parsed.Scheme = "http"
|
||||||
|
case "wss":
|
||||||
|
parsed.Scheme = "https"
|
||||||
|
}
|
||||||
|
return parsed.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendAPIWebsocketResponse stores an upstream websocket response frame in Gin context.
|
||||||
|
func AppendAPIWebsocketResponse(ctx context.Context, cfg *config.Config, payload []byte) {
|
||||||
|
if cfg == nil || !cfg.RequestLog {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(payload)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ginCtx := ginContextFrom(ctx)
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markAPIResponseTimestamp(ginCtx)
|
||||||
|
|
||||||
|
builder := &strings.Builder{}
|
||||||
|
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
|
builder.WriteString("Event: api.websocket.response\n")
|
||||||
|
builder.Write(data)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
|
||||||
|
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordAPIWebsocketError stores an upstream websocket error event in Gin context.
|
||||||
|
func RecordAPIWebsocketError(ctx context.Context, cfg *config.Config, stage string, err error) {
|
||||||
|
if cfg == nil || !cfg.RequestLog || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ginCtx := ginContextFrom(ctx)
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
markAPIResponseTimestamp(ginCtx)
|
||||||
|
|
||||||
|
builder := &strings.Builder{}
|
||||||
|
builder.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
|
builder.WriteString("Event: api.websocket.error\n")
|
||||||
|
if trimmed := strings.TrimSpace(stage); trimmed != "" {
|
||||||
|
builder.WriteString(fmt.Sprintf("Stage: %s\n", trimmed))
|
||||||
|
}
|
||||||
|
builder.WriteString(fmt.Sprintf("Error: %s\n", err.Error()))
|
||||||
|
|
||||||
|
appendAPIWebsocketTimeline(ginCtx, []byte(builder.String()))
|
||||||
|
}
|
||||||
|
|
||||||
func ginContextFrom(ctx context.Context) *gin.Context {
|
func ginContextFrom(ctx context.Context) *gin.Context {
|
||||||
ginCtx, _ := ctx.Value("gin").(*gin.Context)
|
ginCtx, _ := ctx.Value("gin").(*gin.Context)
|
||||||
return ginCtx
|
return ginCtx
|
||||||
@@ -259,6 +404,40 @@ func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt)
|
|||||||
ginCtx.Set(apiResponseKey, []byte(builder.String()))
|
ginCtx.Set(apiResponseKey, []byte(builder.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendAPIWebsocketTimeline(ginCtx *gin.Context, chunk []byte) {
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(chunk)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing, exists := ginCtx.Get(apiWebsocketTimelineKey); exists {
|
||||||
|
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||||
|
combined := make([]byte, 0, len(existingBytes)+len(data)+2)
|
||||||
|
combined = append(combined, existingBytes...)
|
||||||
|
if !bytes.HasSuffix(existingBytes, []byte("\n")) {
|
||||||
|
combined = append(combined, '\n')
|
||||||
|
}
|
||||||
|
combined = append(combined, '\n')
|
||||||
|
combined = append(combined, data...)
|
||||||
|
ginCtx.Set(apiWebsocketTimelineKey, combined)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ginCtx.Set(apiWebsocketTimelineKey, bytes.Clone(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func markAPIResponseTimestamp(ginCtx *gin.Context) {
|
||||||
|
if ginCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := ginCtx.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ginCtx.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
func writeHeaders(builder *strings.Builder, headers http.Header) {
|
func writeHeaders(builder *strings.Builder, headers http.Header) {
|
||||||
if builder == nil {
|
if builder == nil {
|
||||||
return
|
return
|
||||||
@@ -285,7 +464,7 @@ func writeHeaders(builder *strings.Builder, headers http.Header) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatAuthInfo(info upstreamRequestLog) string {
|
func formatAuthInfo(info UpstreamRequestLog) string {
|
||||||
var parts []string
|
var parts []string
|
||||||
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
|
if trimmed := strings.TrimSpace(info.Provider); trimmed != "" {
|
||||||
parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
|
parts = append(parts, fmt.Sprintf("provider=%s", trimmed))
|
||||||
@@ -321,7 +500,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
|
|||||||
return strings.Join(parts, ", ")
|
return strings.Join(parts, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func summarizeErrorBody(contentType string, body []byte) string {
|
func SummarizeErrorBody(contentType string, body []byte) string {
|
||||||
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
|
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
|
||||||
if !isHTML {
|
if !isHTML {
|
||||||
trimmed := bytes.TrimSpace(bytes.ToLower(body))
|
trimmed := bytes.TrimSpace(bytes.ToLower(body))
|
||||||
@@ -379,7 +558,7 @@ func extractJSONErrorMessage(body []byte) string {
|
|||||||
|
|
||||||
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||||
// If no request ID is found in context, it returns the standard logger.
|
// If no request ID is found in context, it returns the standard logger.
|
||||||
func logWithRequestID(ctx context.Context) *log.Entry {
|
func LogWithRequestID(ctx context.Context) *log.Entry {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return log.NewEntry(log.StandardLogger())
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -11,12 +11,12 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
// ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||||
// against the original payload when provided. requestedModel carries the client-visible
|
// against the original payload when provided. requestedModel carries the client-visible
|
||||||
// model name before alias resolution so payload rules can target aliases precisely.
|
// model name before alias resolution so payload rules can target aliases precisely.
|
||||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||||
if cfg == nil || len(payload) == 0 {
|
if cfg == nil || len(payload) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
@@ -244,7 +244,7 @@ func payloadRawValue(value any) ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||||
fallback = strings.TrimSpace(fallback)
|
fallback = strings.TrimSpace(fallback)
|
||||||
if len(opts.Metadata) == 0 {
|
if len(opts.Metadata) == 0 {
|
||||||
return fallback
|
return fallback
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -19,7 +19,7 @@ var (
|
|||||||
httpClientCacheMutex sync.RWMutex
|
httpClientCacheMutex sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
// NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||||
// 3. Use RoundTripper from context if neither are configured
|
// 3. Use RoundTripper from context if neither are configured
|
||||||
@@ -34,7 +34,7 @@ var (
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *http.Client: An HTTP client with configured proxy or transport
|
// - *http.Client: An HTTP client with configured proxy or transport
|
||||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
// Priority 1: Use auth.ProxyURL if configured
|
// Priority 1: Use auth.ProxyURL if configured
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -46,23 +46,18 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build cache key from proxy URL (empty string for no proxy)
|
// If we have a proxy URL configured, try cache first to reuse TCP/TLS connections.
|
||||||
cacheKey := proxyURL
|
if proxyURL != "" {
|
||||||
|
httpClientCacheMutex.RLock()
|
||||||
// Check cache first
|
if cachedClient, ok := httpClientCache[proxyURL]; ok {
|
||||||
httpClientCacheMutex.RLock()
|
httpClientCacheMutex.RUnlock()
|
||||||
if cachedClient, ok := httpClientCache[cacheKey]; ok {
|
if timeout > 0 {
|
||||||
httpClientCacheMutex.RUnlock()
|
return &http.Client{Transport: cachedClient.Transport, Timeout: timeout}
|
||||||
// Return a wrapper with the requested timeout but shared transport
|
|
||||||
if timeout > 0 {
|
|
||||||
return &http.Client{
|
|
||||||
Transport: cachedClient.Transport,
|
|
||||||
Timeout: timeout,
|
|
||||||
}
|
}
|
||||||
|
return cachedClient
|
||||||
}
|
}
|
||||||
return cachedClient
|
httpClientCacheMutex.RUnlock()
|
||||||
}
|
}
|
||||||
httpClientCacheMutex.RUnlock()
|
|
||||||
|
|
||||||
// Create new client
|
// Create new client
|
||||||
httpClient := &http.Client{}
|
httpClient := &http.Client{}
|
||||||
@@ -77,7 +72,7 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
httpClient.Transport = transport
|
httpClient.Transport = transport
|
||||||
// Cache the client
|
// Cache the client
|
||||||
httpClientCacheMutex.Lock()
|
httpClientCacheMutex.Lock()
|
||||||
httpClientCache[cacheKey] = httpClient
|
httpClientCache[proxyURL] = httpClient
|
||||||
httpClientCacheMutex.Unlock()
|
httpClientCacheMutex.Unlock()
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
@@ -90,13 +85,6 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
httpClient.Transport = rt
|
httpClient.Transport = rt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the client for no-proxy case
|
|
||||||
if proxyURL == "" {
|
|
||||||
httpClientCacheMutex.Lock()
|
|
||||||
httpClientCache[cacheKey] = httpClient
|
|
||||||
httpClientCacheMutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
client := newProxyAwareHTTPClient(
|
client := NewProxyAwareHTTPClient(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||||
&cliproxyauth.Auth{ProxyURL: "direct"},
|
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||||
92
internal/runtime/executor/helps/session_id_cache.go
Normal file
92
internal/runtime/executor/helps/session_id_cache.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package helps
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionIDCacheEntry struct {
|
||||||
|
value string
|
||||||
|
expire time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
sessionIDCache = make(map[string]sessionIDCacheEntry)
|
||||||
|
sessionIDCacheMu sync.RWMutex
|
||||||
|
sessionIDCacheCleanupOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionIDTTL = time.Hour
|
||||||
|
sessionIDCacheCleanupPeriod = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
func startSessionIDCacheCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(sessionIDCacheCleanupPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
purgeExpiredSessionIDs()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func purgeExpiredSessionIDs() {
|
||||||
|
now := time.Now()
|
||||||
|
sessionIDCacheMu.Lock()
|
||||||
|
for key, entry := range sessionIDCache {
|
||||||
|
if !entry.expire.After(now) {
|
||||||
|
delete(sessionIDCache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sessionIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionIDCacheKey(apiKey string) string {
|
||||||
|
sum := sha256.Sum256([]byte(apiKey))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedSessionID returns a stable session UUID per apiKey, refreshing the TTL on each access.
|
||||||
|
func CachedSessionID(apiKey string) string {
|
||||||
|
if apiKey == "" {
|
||||||
|
return uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionIDCacheCleanupOnce.Do(startSessionIDCacheCleanup)
|
||||||
|
|
||||||
|
key := sessionIDCacheKey(apiKey)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
sessionIDCacheMu.RLock()
|
||||||
|
entry, ok := sessionIDCache[key]
|
||||||
|
valid := ok && entry.value != "" && entry.expire.After(now)
|
||||||
|
sessionIDCacheMu.RUnlock()
|
||||||
|
if valid {
|
||||||
|
sessionIDCacheMu.Lock()
|
||||||
|
entry = sessionIDCache[key]
|
||||||
|
if entry.value != "" && entry.expire.After(now) {
|
||||||
|
entry.expire = now.Add(sessionIDTTL)
|
||||||
|
sessionIDCache[key] = entry
|
||||||
|
sessionIDCacheMu.Unlock()
|
||||||
|
return entry.value
|
||||||
|
}
|
||||||
|
sessionIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
newID := uuid.New().String()
|
||||||
|
|
||||||
|
sessionIDCacheMu.Lock()
|
||||||
|
entry, ok = sessionIDCache[key]
|
||||||
|
if !ok || entry.value == "" || !entry.expire.After(now) {
|
||||||
|
entry.value = newID
|
||||||
|
}
|
||||||
|
entry.expire = now.Add(sessionIDTTL)
|
||||||
|
sessionIDCache[key] = entry
|
||||||
|
sessionIDCacheMu.Unlock()
|
||||||
|
return entry.value
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -11,100 +9,80 @@ import (
|
|||||||
"github.com/tiktoken-go/tokenizer"
|
"github.com/tiktoken-go/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
// tokenizerCache stores tokenizer instances to avoid repeated creation.
|
||||||
var tokenizerCache sync.Map
|
var tokenizerCache sync.Map
|
||||||
|
|
||||||
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
type adjustedTokenizer struct {
|
||||||
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
tokenizer.Codec
|
||||||
type TokenizerWrapper struct {
|
adjustmentFactor float64
|
||||||
Codec tokenizer.Codec
|
|
||||||
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count returns the token count with adjustment factor applied
|
func (tw *adjustedTokenizer) Count(text string) (int, error) {
|
||||||
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
|
||||||
count, err := tw.Codec.Count(text)
|
count, err := tw.Codec.Count(text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
if tw.adjustmentFactor > 0 && tw.adjustmentFactor != 1.0 {
|
||||||
return int(float64(count) * tw.AdjustmentFactor), nil
|
return int(float64(count) * tw.adjustmentFactor), nil
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTokenizer returns a cached tokenizer for the given model.
|
// TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||||
// This improves performance by avoiding repeated tokenizer creation.
|
// For Claude-like models, it applies an adjustment factor since tiktoken may underestimate token counts.
|
||||||
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
func TokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||||
// Check cache first
|
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||||
if cached, ok := tokenizerCache.Load(model); ok {
|
if cached, ok := tokenizerCache.Load(sanitized); ok {
|
||||||
return cached.(*TokenizerWrapper), nil
|
return cached.(tokenizer.Codec), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache miss, create new tokenizer
|
enc, err := tokenizerForModel(sanitized)
|
||||||
wrapper, err := tokenizerForModel(model)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store in cache (use LoadOrStore to handle race conditions)
|
actual, _ := tokenizerCache.LoadOrStore(sanitized, enc)
|
||||||
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
return actual.(tokenizer.Codec), nil
|
||||||
return actual.(*TokenizerWrapper), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
func tokenizerForModel(sanitized string) (tokenizer.Codec, error) {
|
||||||
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
if sanitized == "" {
|
||||||
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
}
|
||||||
|
|
||||||
// Claude models use cl100k_base with 1.1 adjustment factor
|
// Claude models use cl100k_base with an adjustment factor because tiktoken may underestimate.
|
||||||
// because tiktoken may underestimate Claude's actual token count
|
|
||||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
return &adjustedTokenizer{Codec: enc, adjustmentFactor: 1.1}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var enc tokenizer.Codec
|
|
||||||
var err error
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case sanitized == "":
|
|
||||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
|
||||||
case strings.HasPrefix(sanitized, "gpt-5.2"):
|
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
|
||||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
|
||||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
return tokenizer.ForModel(tokenizer.GPT5)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
return tokenizer.ForModel(tokenizer.GPT41)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
return tokenizer.ForModel(tokenizer.GPT4)
|
||||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||||
case strings.HasPrefix(sanitized, "o1"):
|
case strings.HasPrefix(sanitized, "o1"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.O1)
|
return tokenizer.ForModel(tokenizer.O1)
|
||||||
case strings.HasPrefix(sanitized, "o3"):
|
case strings.HasPrefix(sanitized, "o3"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.O3)
|
return tokenizer.ForModel(tokenizer.O3)
|
||||||
case strings.HasPrefix(sanitized, "o4"):
|
case strings.HasPrefix(sanitized, "o4"):
|
||||||
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||||
default:
|
default:
|
||||||
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
return tokenizer.Get(tokenizer.O200kBase)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
// CountOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||||
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||||
if enc == nil {
|
if enc == nil {
|
||||||
return 0, fmt.Errorf("encoder is nil")
|
return 0, fmt.Errorf("encoder is nil")
|
||||||
}
|
}
|
||||||
@@ -128,22 +106,15 @@ func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count text tokens
|
|
||||||
count, err := enc.Count(joined)
|
count, err := enc.Count(joined)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
return int64(count), nil
|
||||||
// Extract and add image tokens from placeholders
|
|
||||||
imageTokens := extractImageTokens(joined)
|
|
||||||
|
|
||||||
return int64(count) + int64(imageTokens), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
// CountClaudeChatTokens approximates prompt tokens for Claude API chat payloads.
|
||||||
// This handles Claude's message format with system, messages, and tools.
|
func CountClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||||
// Image tokens are estimated based on image dimensions when available.
|
|
||||||
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
|
||||||
if enc == nil {
|
if enc == nil {
|
||||||
return 0, fmt.Errorf("encoder is nil")
|
return 0, fmt.Errorf("encoder is nil")
|
||||||
}
|
}
|
||||||
@@ -153,185 +124,25 @@ func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error)
|
|||||||
|
|
||||||
root := gjson.ParseBytes(payload)
|
root := gjson.ParseBytes(payload)
|
||||||
segments := make([]string, 0, 32)
|
segments := make([]string, 0, 32)
|
||||||
|
imageTokens := 0
|
||||||
|
|
||||||
// Collect system prompt (can be string or array of content blocks)
|
collectClaudeContent(root.Get("system"), &segments, &imageTokens)
|
||||||
collectClaudeSystem(root.Get("system"), &segments)
|
collectClaudeMessages(root.Get("messages"), &segments, &imageTokens)
|
||||||
|
|
||||||
// Collect messages
|
|
||||||
collectClaudeMessages(root.Get("messages"), &segments)
|
|
||||||
|
|
||||||
// Collect tools
|
|
||||||
collectClaudeTools(root.Get("tools"), &segments)
|
collectClaudeTools(root.Get("tools"), &segments)
|
||||||
|
|
||||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||||
if joined == "" {
|
if joined == "" {
|
||||||
return 0, nil
|
return int64(imageTokens), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count text tokens
|
|
||||||
count, err := enc.Count(joined)
|
count, err := enc.Count(joined)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
return int64(count + imageTokens), nil
|
||||||
// Extract and add image tokens from placeholders
|
|
||||||
imageTokens := extractImageTokens(joined)
|
|
||||||
|
|
||||||
return int64(count) + int64(imageTokens), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
// BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||||
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
func BuildOpenAIUsageJSON(count int64) []byte {
|
||||||
|
|
||||||
// extractImageTokens extracts image token estimates from placeholder text.
|
|
||||||
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
|
||||||
func extractImageTokens(text string) int {
|
|
||||||
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
|
||||||
total := 0
|
|
||||||
for _, match := range matches {
|
|
||||||
if len(match) > 1 {
|
|
||||||
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
|
||||||
total += tokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
|
||||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
|
||||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
|
||||||
func estimateImageTokens(width, height float64) int {
|
|
||||||
if width <= 0 || height <= 0 {
|
|
||||||
// No valid dimensions, use default estimate (medium-sized image)
|
|
||||||
return 1000
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens := int(width * height / 750)
|
|
||||||
|
|
||||||
// Apply bounds
|
|
||||||
if tokens < 85 {
|
|
||||||
tokens = 85
|
|
||||||
}
|
|
||||||
if tokens > 1590 {
|
|
||||||
tokens = 1590
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// collectClaudeSystem extracts text from Claude's system field.
|
|
||||||
// System can be a string or an array of content blocks.
|
|
||||||
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
|
||||||
if !system.Exists() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if system.Type == gjson.String {
|
|
||||||
addIfNotEmpty(segments, system.String())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if system.IsArray() {
|
|
||||||
system.ForEach(func(_, block gjson.Result) bool {
|
|
||||||
blockType := block.Get("type").String()
|
|
||||||
if blockType == "text" || blockType == "" {
|
|
||||||
addIfNotEmpty(segments, block.Get("text").String())
|
|
||||||
}
|
|
||||||
// Also handle plain string blocks
|
|
||||||
if block.Type == gjson.String {
|
|
||||||
addIfNotEmpty(segments, block.String())
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// collectClaudeMessages extracts text from Claude's messages array.
|
|
||||||
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
|
||||||
if !messages.Exists() || !messages.IsArray() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
messages.ForEach(func(_, message gjson.Result) bool {
|
|
||||||
addIfNotEmpty(segments, message.Get("role").String())
|
|
||||||
collectClaudeContent(message.Get("content"), segments)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// collectClaudeContent extracts text from Claude's content field.
|
|
||||||
// Content can be a string or an array of content blocks.
|
|
||||||
// For images, estimates token count based on dimensions when available.
|
|
||||||
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
|
||||||
if !content.Exists() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if content.Type == gjson.String {
|
|
||||||
addIfNotEmpty(segments, content.String())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if content.IsArray() {
|
|
||||||
content.ForEach(func(_, part gjson.Result) bool {
|
|
||||||
partType := part.Get("type").String()
|
|
||||||
switch partType {
|
|
||||||
case "text":
|
|
||||||
addIfNotEmpty(segments, part.Get("text").String())
|
|
||||||
case "image":
|
|
||||||
// Estimate image tokens based on dimensions if available
|
|
||||||
source := part.Get("source")
|
|
||||||
if source.Exists() {
|
|
||||||
width := source.Get("width").Float()
|
|
||||||
height := source.Get("height").Float()
|
|
||||||
if width > 0 && height > 0 {
|
|
||||||
tokens := estimateImageTokens(width, height)
|
|
||||||
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
|
||||||
} else {
|
|
||||||
// No dimensions available, use default estimate
|
|
||||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No source info, use default estimate
|
|
||||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
|
||||||
}
|
|
||||||
case "tool_use":
|
|
||||||
addIfNotEmpty(segments, part.Get("id").String())
|
|
||||||
addIfNotEmpty(segments, part.Get("name").String())
|
|
||||||
if input := part.Get("input"); input.Exists() {
|
|
||||||
addIfNotEmpty(segments, input.Raw)
|
|
||||||
}
|
|
||||||
case "tool_result":
|
|
||||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
|
||||||
collectClaudeContent(part.Get("content"), segments)
|
|
||||||
case "thinking":
|
|
||||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
|
||||||
default:
|
|
||||||
// For unknown types, try to extract any text content
|
|
||||||
if part.Type == gjson.String {
|
|
||||||
addIfNotEmpty(segments, part.String())
|
|
||||||
} else if part.Type == gjson.JSON {
|
|
||||||
addIfNotEmpty(segments, part.Raw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// collectClaudeTools extracts text from Claude's tools array.
|
|
||||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
|
||||||
if !tools.Exists() || !tools.IsArray() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
|
||||||
addIfNotEmpty(segments, tool.Get("name").String())
|
|
||||||
addIfNotEmpty(segments, tool.Get("description").String())
|
|
||||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
|
||||||
addIfNotEmpty(segments, inputSchema.Raw)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
|
||||||
func buildOpenAIUsageJSON(count int64) []byte {
|
|
||||||
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
|
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,6 +201,10 @@ func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CollectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||||
|
collectOpenAIContent(content, segments)
|
||||||
|
}
|
||||||
|
|
||||||
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
|
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
|
||||||
if !calls.Exists() || !calls.IsArray() {
|
if !calls.Exists() || !calls.IsArray() {
|
||||||
return
|
return
|
||||||
@@ -487,6 +302,98 @@ func appendToolPayload(tool gjson.Result, segments *[]string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func collectClaudeMessages(messages gjson.Result, segments *[]string, imageTokens *int) {
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages.ForEach(func(_, message gjson.Result) bool {
|
||||||
|
addIfNotEmpty(segments, message.Get("role").String())
|
||||||
|
collectClaudeContent(message.Get("content"), segments, imageTokens)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectClaudeContent(content gjson.Result, segments *[]string, imageTokens *int) {
|
||||||
|
if !content.Exists() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, content.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
content.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
switch partType {
|
||||||
|
case "text":
|
||||||
|
addIfNotEmpty(segments, part.Get("text").String())
|
||||||
|
case "image":
|
||||||
|
source := part.Get("source")
|
||||||
|
width := source.Get("width").Float()
|
||||||
|
height := source.Get("height").Float()
|
||||||
|
if imageTokens != nil {
|
||||||
|
*imageTokens += estimateImageTokens(width, height)
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
addIfNotEmpty(segments, part.Get("id").String())
|
||||||
|
addIfNotEmpty(segments, part.Get("name").String())
|
||||||
|
if input := part.Get("input"); input.Exists() {
|
||||||
|
addIfNotEmpty(segments, input.Raw)
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||||
|
collectClaudeContent(part.Get("content"), segments, imageTokens)
|
||||||
|
case "thinking":
|
||||||
|
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||||
|
default:
|
||||||
|
if part.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, part.String())
|
||||||
|
} else if part.Type == gjson.JSON {
|
||||||
|
addIfNotEmpty(segments, part.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.Type == gjson.JSON {
|
||||||
|
addIfNotEmpty(segments, content.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||||
|
addIfNotEmpty(segments, tool.Get("name").String())
|
||||||
|
addIfNotEmpty(segments, tool.Get("description").String())
|
||||||
|
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||||
|
addIfNotEmpty(segments, inputSchema.Raw)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||||
|
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||||
|
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||||
|
func estimateImageTokens(width, height float64) int {
|
||||||
|
if width <= 0 || height <= 0 {
|
||||||
|
// No valid dimensions, use default estimate (medium-sized image).
|
||||||
|
return 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := int(width * height / 750)
|
||||||
|
if tokens < 85 {
|
||||||
|
return 85
|
||||||
|
}
|
||||||
|
if tokens > 1590 {
|
||||||
|
return 1590
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
func addIfNotEmpty(segments *[]string, value string) {
|
func addIfNotEmpty(segments *[]string, value string) {
|
||||||
if segments == nil {
|
if segments == nil {
|
||||||
return
|
return
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type usageReporter struct {
|
type UsageReporter struct {
|
||||||
provider string
|
provider string
|
||||||
model string
|
model string
|
||||||
authID string
|
authID string
|
||||||
@@ -26,9 +26,9 @@ type usageReporter struct {
|
|||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
|
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
|
||||||
apiKey := apiKeyFromContext(ctx)
|
apiKey := APIKeyFromContext(ctx)
|
||||||
reporter := &usageReporter{
|
reporter := &UsageReporter{
|
||||||
provider: provider,
|
provider: provider,
|
||||||
model: model,
|
model: model,
|
||||||
requestedAt: time.Now(),
|
requestedAt: time.Now(),
|
||||||
@@ -42,24 +42,24 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
|
|||||||
return reporter
|
return reporter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
|
func (r *UsageReporter) Publish(ctx context.Context, detail usage.Detail) {
|
||||||
r.publishWithOutcome(ctx, detail, false)
|
r.publishWithOutcome(ctx, detail, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageReporter) publishFailure(ctx context.Context) {
|
func (r *UsageReporter) PublishFailure(ctx context.Context) {
|
||||||
r.publishWithOutcome(ctx, usage.Detail{}, true)
|
r.publishWithOutcome(ctx, usage.Detail{}, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
|
func (r *UsageReporter) TrackFailure(ctx context.Context, errPtr *error) {
|
||||||
if r == nil || errPtr == nil {
|
if r == nil || errPtr == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if *errPtr != nil {
|
if *errPtr != nil {
|
||||||
r.publishFailure(ctx)
|
r.PublishFailure(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
|
func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -69,9 +69,6 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
|||||||
detail.TotalTokens = total
|
detail.TotalTokens = total
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if detail.InputTokens == 0 && detail.OutputTokens == 0 && detail.ReasoningTokens == 0 && detail.CachedTokens == 0 && detail.TotalTokens == 0 && !failed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.once.Do(func() {
|
r.once.Do(func() {
|
||||||
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
|
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
|
||||||
})
|
})
|
||||||
@@ -81,7 +78,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
|||||||
// It is safe to call multiple times; only the first call wins due to once.Do.
|
// It is safe to call multiple times; only the first call wins due to once.Do.
|
||||||
// This is used to ensure request counting even when upstream responses do not
|
// This is used to ensure request counting even when upstream responses do not
|
||||||
// include any usage fields (tokens), especially for streaming paths.
|
// include any usage fields (tokens), especially for streaming paths.
|
||||||
func (r *usageReporter) ensurePublished(ctx context.Context) {
|
func (r *UsageReporter) EnsurePublished(ctx context.Context) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -90,7 +87,7 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return usage.Record{Detail: detail, Failed: failed}
|
return usage.Record{Detail: detail, Failed: failed}
|
||||||
}
|
}
|
||||||
@@ -108,7 +105,7 @@ func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageReporter) latency() time.Duration {
|
func (r *UsageReporter) latency() time.Duration {
|
||||||
if r == nil || r.requestedAt.IsZero() {
|
if r == nil || r.requestedAt.IsZero() {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -119,7 +116,7 @@ func (r *usageReporter) latency() time.Duration {
|
|||||||
return latency
|
return latency
|
||||||
}
|
}
|
||||||
|
|
||||||
func apiKeyFromContext(ctx context.Context) string {
|
func APIKeyFromContext(ctx context.Context) string {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -184,7 +181,7 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseCodexUsage(data []byte) (usage.Detail, bool) {
|
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
|
||||||
usageNode := gjson.ParseBytes(data).Get("response.usage")
|
usageNode := gjson.ParseBytes(data).Get("response.usage")
|
||||||
if !usageNode.Exists() {
|
if !usageNode.Exists() {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
@@ -203,7 +200,7 @@ func parseCodexUsage(data []byte) (usage.Detail, bool) {
|
|||||||
return detail, true
|
return detail, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpenAIUsage(data []byte) usage.Detail {
|
func ParseOpenAIUsage(data []byte) usage.Detail {
|
||||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||||
if !usageNode.Exists() {
|
if !usageNode.Exists() {
|
||||||
return usage.Detail{}
|
return usage.Detail{}
|
||||||
@@ -238,7 +235,7 @@ func parseOpenAIUsage(data []byte) usage.Detail {
|
|||||||
return detail
|
return detail
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
payload := jsonPayload(line)
|
payload := jsonPayload(line)
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
@@ -247,59 +244,40 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
if !usageNode.Exists() {
|
if !usageNode.Exists() {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inputNode := usageNode.Get("prompt_tokens")
|
||||||
|
if !inputNode.Exists() {
|
||||||
|
inputNode = usageNode.Get("input_tokens")
|
||||||
|
}
|
||||||
|
outputNode := usageNode.Get("completion_tokens")
|
||||||
|
if !outputNode.Exists() {
|
||||||
|
outputNode = usageNode.Get("output_tokens")
|
||||||
|
}
|
||||||
detail := usage.Detail{
|
detail := usage.Detail{
|
||||||
InputTokens: usageNode.Get("prompt_tokens").Int(),
|
InputTokens: inputNode.Int(),
|
||||||
OutputTokens: usageNode.Get("completion_tokens").Int(),
|
OutputTokens: outputNode.Int(),
|
||||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||||
}
|
}
|
||||||
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
|
|
||||||
|
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
|
||||||
|
if !cached.Exists() {
|
||||||
|
cached = usageNode.Get("input_tokens_details.cached_tokens")
|
||||||
|
}
|
||||||
|
if cached.Exists() {
|
||||||
detail.CachedTokens = cached.Int()
|
detail.CachedTokens = cached.Int()
|
||||||
}
|
}
|
||||||
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
|
||||||
|
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
|
||||||
|
if !reasoning.Exists() {
|
||||||
|
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
|
||||||
|
}
|
||||||
|
if reasoning.Exists() {
|
||||||
detail.ReasoningTokens = reasoning.Int()
|
detail.ReasoningTokens = reasoning.Int()
|
||||||
}
|
}
|
||||||
return detail, true
|
return detail, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
|
func ParseClaudeUsage(data []byte) usage.Detail {
|
||||||
detail := usage.Detail{
|
|
||||||
InputTokens: usageNode.Get("input_tokens").Int(),
|
|
||||||
OutputTokens: usageNode.Get("output_tokens").Int(),
|
|
||||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
|
||||||
}
|
|
||||||
if detail.TotalTokens == 0 {
|
|
||||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
|
|
||||||
}
|
|
||||||
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
|
|
||||||
detail.CachedTokens = cached.Int()
|
|
||||||
}
|
|
||||||
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
|
||||||
detail.ReasoningTokens = reasoning.Int()
|
|
||||||
}
|
|
||||||
return detail
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
|
|
||||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
|
||||||
if !usageNode.Exists() {
|
|
||||||
return usage.Detail{}
|
|
||||||
}
|
|
||||||
return parseOpenAIResponsesUsageDetail(usageNode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
|
|
||||||
payload := jsonPayload(line)
|
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
|
||||||
return usage.Detail{}, false
|
|
||||||
}
|
|
||||||
usageNode := gjson.GetBytes(payload, "usage")
|
|
||||||
if !usageNode.Exists() {
|
|
||||||
return usage.Detail{}, false
|
|
||||||
}
|
|
||||||
return parseOpenAIResponsesUsageDetail(usageNode), true
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseClaudeUsage(data []byte) usage.Detail {
|
|
||||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||||
if !usageNode.Exists() {
|
if !usageNode.Exists() {
|
||||||
return usage.Detail{}
|
return usage.Detail{}
|
||||||
@@ -317,7 +295,7 @@ func parseClaudeUsage(data []byte) usage.Detail {
|
|||||||
return detail
|
return detail
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
func ParseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
payload := jsonPayload(line)
|
payload := jsonPayload(line)
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
@@ -352,7 +330,7 @@ func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
|
|||||||
return detail
|
return detail
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiCLIUsage(data []byte) usage.Detail {
|
func ParseGeminiCLIUsage(data []byte) usage.Detail {
|
||||||
usageNode := gjson.ParseBytes(data)
|
usageNode := gjson.ParseBytes(data)
|
||||||
node := usageNode.Get("response.usageMetadata")
|
node := usageNode.Get("response.usageMetadata")
|
||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
@@ -364,7 +342,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
|
|||||||
return parseGeminiFamilyUsageDetail(node)
|
return parseGeminiFamilyUsageDetail(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiUsage(data []byte) usage.Detail {
|
func ParseGeminiUsage(data []byte) usage.Detail {
|
||||||
usageNode := gjson.ParseBytes(data)
|
usageNode := gjson.ParseBytes(data)
|
||||||
node := usageNode.Get("usageMetadata")
|
node := usageNode.Get("usageMetadata")
|
||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
@@ -376,7 +354,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
|
|||||||
return parseGeminiFamilyUsageDetail(node)
|
return parseGeminiFamilyUsageDetail(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
payload := jsonPayload(line)
|
payload := jsonPayload(line)
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
@@ -391,7 +369,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
return parseGeminiFamilyUsageDetail(node), true
|
return parseGeminiFamilyUsageDetail(node), true
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
payload := jsonPayload(line)
|
payload := jsonPayload(line)
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
@@ -406,7 +384,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
|||||||
return parseGeminiFamilyUsageDetail(node), true
|
return parseGeminiFamilyUsageDetail(node), true
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAntigravityUsage(data []byte) usage.Detail {
|
func ParseAntigravityUsage(data []byte) usage.Detail {
|
||||||
usageNode := gjson.ParseBytes(data)
|
usageNode := gjson.ParseBytes(data)
|
||||||
node := usageNode.Get("response.usageMetadata")
|
node := usageNode.Get("response.usageMetadata")
|
||||||
if !node.Exists() {
|
if !node.Exists() {
|
||||||
@@ -421,7 +399,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
|
|||||||
return parseGeminiFamilyUsageDetail(node)
|
return parseGeminiFamilyUsageDetail(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
func ParseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||||
payload := jsonPayload(line)
|
payload := jsonPayload(line)
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
return usage.Detail{}, false
|
return usage.Detail{}, false
|
||||||
@@ -590,6 +568,10 @@ func isStopChunkWithoutUsage(jsonBytes []byte) bool {
|
|||||||
return !hasUsageMetadata(jsonBytes)
|
return !hasUsageMetadata(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func JSONPayload(line []byte) []byte {
|
||||||
|
return jsonPayload(line)
|
||||||
|
}
|
||||||
|
|
||||||
func jsonPayload(line []byte) []byte {
|
func jsonPayload(line []byte) []byte {
|
||||||
trimmed := bytes.TrimSpace(line)
|
trimmed := bytes.TrimSpace(line)
|
||||||
if len(trimmed) == 0 {
|
if len(trimmed) == 0 {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||||
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||||
detail := parseOpenAIUsage(data)
|
detail := ParseOpenAIUsage(data)
|
||||||
if detail.InputTokens != 1 {
|
if detail.InputTokens != 1 {
|
||||||
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
|
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
|||||||
|
|
||||||
func TestParseOpenAIUsageResponses(t *testing.T) {
|
func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||||
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
|
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
|
||||||
detail := parseOpenAIUsage(data)
|
detail := ParseOpenAIUsage(data)
|
||||||
if detail.InputTokens != 10 {
|
if detail.InputTokens != 10 {
|
||||||
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
|
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
|
||||||
}
|
}
|
||||||
@@ -48,7 +48,7 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||||
reporter := &usageReporter{
|
reporter := &UsageReporter{
|
||||||
provider: "openai",
|
provider: "openai",
|
||||||
model: "gpt-5.4",
|
model: "gpt-5.4",
|
||||||
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
requestedAt: time.Now().Add(-1500 * time.Millisecond),
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
@@ -49,7 +49,7 @@ func userIDCacheKey(apiKey string) string {
|
|||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func cachedUserID(apiKey string) string {
|
func CachedUserID(apiKey string) string {
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return generateFakeUserID()
|
return generateFakeUserID()
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package executor
|
package helps
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@@ -14,8 +14,8 @@ func resetUserIDCache() {
|
|||||||
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||||
resetUserIDCache()
|
resetUserIDCache()
|
||||||
|
|
||||||
first := cachedUserID("api-key-1")
|
first := CachedUserID("api-key-1")
|
||||||
second := cachedUserID("api-key-1")
|
second := CachedUserID("api-key-1")
|
||||||
|
|
||||||
if first == "" {
|
if first == "" {
|
||||||
t.Fatal("expected generated user_id to be non-empty")
|
t.Fatal("expected generated user_id to be non-empty")
|
||||||
@@ -28,7 +28,7 @@ func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
|||||||
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||||
resetUserIDCache()
|
resetUserIDCache()
|
||||||
|
|
||||||
expiredID := cachedUserID("api-key-expired")
|
expiredID := CachedUserID("api-key-expired")
|
||||||
cacheKey := userIDCacheKey("api-key-expired")
|
cacheKey := userIDCacheKey("api-key-expired")
|
||||||
userIDCacheMu.Lock()
|
userIDCacheMu.Lock()
|
||||||
userIDCache[cacheKey] = userIDCacheEntry{
|
userIDCache[cacheKey] = userIDCacheEntry{
|
||||||
@@ -37,7 +37,7 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
userIDCacheMu.Unlock()
|
userIDCacheMu.Unlock()
|
||||||
|
|
||||||
newID := cachedUserID("api-key-expired")
|
newID := CachedUserID("api-key-expired")
|
||||||
if newID == expiredID {
|
if newID == expiredID {
|
||||||
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||||
}
|
}
|
||||||
@@ -49,8 +49,8 @@ func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
|||||||
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||||
resetUserIDCache()
|
resetUserIDCache()
|
||||||
|
|
||||||
first := cachedUserID("api-key-1")
|
first := CachedUserID("api-key-1")
|
||||||
second := cachedUserID("api-key-2")
|
second := CachedUserID("api-key-2")
|
||||||
|
|
||||||
if first == second {
|
if first == second {
|
||||||
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||||
@@ -61,7 +61,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
|||||||
resetUserIDCache()
|
resetUserIDCache()
|
||||||
|
|
||||||
key := "api-key-renew"
|
key := "api-key-renew"
|
||||||
id := cachedUserID(key)
|
id := CachedUserID(key)
|
||||||
cacheKey := userIDCacheKey(key)
|
cacheKey := userIDCacheKey(key)
|
||||||
|
|
||||||
soon := time.Now()
|
soon := time.Now()
|
||||||
@@ -72,7 +72,7 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
userIDCacheMu.Unlock()
|
userIDCacheMu.Unlock()
|
||||||
|
|
||||||
if refreshed := cachedUserID(key); refreshed != id {
|
if refreshed := CachedUserID(key); refreshed != id {
|
||||||
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||||
}
|
}
|
||||||
|
|
||||||
188
internal/runtime/executor/helps/utls_client.go
Normal file
188
internal/runtime/executor/helps/utls_client.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package helps
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tls "github.com/refraction-networking/utls"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||||
|
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||||
|
type utlsRoundTripper struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
connections map[string]*http2.ClientConn
|
||||||
|
pending map[string]*sync.Cond
|
||||||
|
dialer proxy.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
|
||||||
|
var dialer proxy.Dialer = proxy.Direct
|
||||||
|
if proxyURL != "" {
|
||||||
|
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
|
||||||
|
if errBuild != nil {
|
||||||
|
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
|
||||||
|
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||||
|
dialer = proxyDialer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &utlsRoundTripper{
|
||||||
|
connections: make(map[string]*http2.ClientConn),
|
||||||
|
pending: make(map[string]*sync.Cond),
|
||||||
|
dialer: dialer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
|
t.mu.Lock()
|
||||||
|
|
||||||
|
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cond, ok := t.pending[host]; ok {
|
||||||
|
cond.Wait()
|
||||||
|
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cond := sync.NewCond(&t.mu)
|
||||||
|
t.pending[host] = cond
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
h2Conn, err := t.createConnection(host, addr)
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
delete(t.pending, host)
|
||||||
|
cond.Broadcast()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.connections[host] = h2Conn
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
|
conn, err := t.dialer.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{ServerName: host}
|
||||||
|
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||||
|
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := &http2.Transport{}
|
||||||
|
h2Conn, err := tr.NewClientConn(tlsConn)
|
||||||
|
if err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
hostname := req.URL.Hostname()
|
||||||
|
port := req.URL.Port()
|
||||||
|
if port == "" {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
addr := net.JoinHostPort(hostname, port)
|
||||||
|
|
||||||
|
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h2Conn.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.mu.Lock()
|
||||||
|
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
|
||||||
|
delete(t.connections, hostname)
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// anthropicHosts contains the hosts that should use utls Chrome TLS fingerprint.
|
||||||
|
var anthropicHosts = map[string]struct{}{
|
||||||
|
"api.anthropic.com": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallbackRoundTripper uses utls for Anthropic HTTPS hosts and falls back to
|
||||||
|
// standard transport for all other requests (non-HTTPS or non-Anthropic hosts).
|
||||||
|
type fallbackRoundTripper struct {
|
||||||
|
utls *utlsRoundTripper
|
||||||
|
fallback http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
if _, ok := anthropicHosts[strings.ToLower(req.URL.Hostname())]; ok {
|
||||||
|
return f.utls.RoundTrip(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return f.fallback.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUtlsHTTPClient creates an HTTP client using utls Chrome TLS fingerprint.
|
||||||
|
// Use this for Claude API requests to match real Claude Code's TLS behavior.
|
||||||
|
// Falls back to standard transport for non-HTTPS requests.
|
||||||
|
func NewUtlsHTTPClient(cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
|
var proxyURL string
|
||||||
|
if auth != nil {
|
||||||
|
proxyURL = strings.TrimSpace(auth.ProxyURL)
|
||||||
|
}
|
||||||
|
if proxyURL == "" && cfg != nil {
|
||||||
|
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
utlsRT := newUtlsRoundTripper(proxyURL)
|
||||||
|
|
||||||
|
var standardTransport http.RoundTripper = &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
}
|
||||||
|
if proxyURL != "" {
|
||||||
|
if transport := buildProxyTransport(proxyURL); transport != nil {
|
||||||
|
standardTransport = transport
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &fallbackRoundTripper{
|
||||||
|
utls: utlsRT,
|
||||||
|
fallback: standardTransport,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if timeout > 0 {
|
||||||
|
client.Timeout = timeout
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -66,7 +67,7 @@ func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,8 +87,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
baseURL = iflowauth.DefaultAPIBaseURL
|
baseURL = iflowauth.DefaultAPIBaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
@@ -106,8 +107,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
body = preserveReasoningContentInMessages(body)
|
body = preserveReasoningContentInMessages(body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
|
|
||||||
@@ -116,13 +117,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyIFlowHeaders(httpReq, apiKey, false)
|
applyIFlowHeaders(httpReq, apiKey, false)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -134,10 +140,10 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -145,25 +151,25 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||||
// Ensure usage is recorded even if upstream omits usage metadata.
|
// Ensure usage is recorded even if upstream omits usage metadata.
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
|
|
||||||
var param any
|
var param any
|
||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
@@ -189,8 +195,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
baseURL = iflowauth.DefaultAPIBaseURL
|
baseURL = iflowauth.DefaultAPIBaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
@@ -214,8 +220,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||||
body = ensureToolsArray(body)
|
body = ensureToolsArray(body)
|
||||||
}
|
}
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
|
|
||||||
@@ -224,13 +230,18 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyIFlowHeaders(httpReq, apiKey, true)
|
applyIFlowHeaders(httpReq, apiKey, true)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -242,21 +253,21 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
data, _ := io.ReadAll(httpResp.Body)
|
data, _ := io.ReadAll(httpResp.Body)
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -275,9 +286,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
@@ -285,12 +296,12 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
// Guarantee a usage record exists even if the stream never emitted usage data.
|
// Guarantee a usage record exists even if the stream never emitted usage data.
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
@@ -303,17 +314,17 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||||
|
|
||||||
enc, err := tokenizerForModel(baseModel)
|
enc, err := helps.TokenizerForModel(baseModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := countOpenAIChatTokens(enc, body)
|
count, err := helps.CountOpenAIChatTokens(enc, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
usageJSON := buildOpenAIUsageJSON(count)
|
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ import (
|
|||||||
|
|
||||||
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -45,6 +47,11 @@ func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
|
|||||||
if strings.TrimSpace(token) != "" {
|
if strings.TrimSpace(token) != "" {
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
}
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,7 +67,7 @@ func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,8 +83,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
|
|
||||||
token := kimiCreds(auth)
|
token := kimiCreds(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
originalPayloadSource := req.Payload
|
originalPayloadSource := req.Payload
|
||||||
@@ -100,8 +107,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, err = normalizeKimiToolMessageLinks(body)
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -113,13 +120,18 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -131,10 +143,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -142,21 +154,21 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
@@ -176,8 +188,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
token := kimiCreds(auth)
|
token := kimiCreds(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
originalPayloadSource := req.Payload
|
originalPayloadSource := req.Payload
|
||||||
@@ -204,8 +216,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||||
}
|
}
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, err = normalizeKimiToolMessageLinks(body)
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -217,13 +229,18 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -235,17 +252,17 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -265,9 +282,9 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
@@ -279,8 +296,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -65,15 +66,15 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
baseURL, apiKey := e.resolveCredentials(auth)
|
baseURL, apiKey := e.resolveCredentials(auth)
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
@@ -95,8 +96,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
originalPayload := originalPayloadSource
|
originalPayload := originalPayloadSource
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
|
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
|
||||||
translated = updated
|
translated = updated
|
||||||
@@ -129,7 +130,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -141,10 +142,10 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -152,23 +153,23 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
body, err := io.ReadAll(httpResp.Body)
|
body, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
|
||||||
// Ensure we at least record the request even if upstream doesn't return usage
|
// Ensure we at least record the request even if upstream doesn't return usage
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
// Translate response back to source format when needed
|
// Translate response back to source format when needed
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||||
@@ -179,8 +180,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
baseURL, apiKey := e.resolveCredentials(auth)
|
baseURL, apiKey := e.resolveCredentials(auth)
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
@@ -197,8 +198,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
originalPayload := originalPayloadSource
|
originalPayload := originalPayloadSource
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -232,7 +233,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -244,17 +245,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -274,9 +275,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
if len(line) == 0 {
|
if len(line) == 0 {
|
||||||
continue
|
continue
|
||||||
@@ -294,12 +295,20 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
// In case the upstream close the stream without a terminal [DONE] marker.
|
||||||
|
// Feed a synthetic done marker through the translator so pending
|
||||||
|
// response.completed events are still emitted exactly once.
|
||||||
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m)
|
||||||
|
for i := range chunks {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Ensure we record the request if no usage chunk was ever seen
|
// Ensure we record the request if no usage chunk was ever seen
|
||||||
reporter.ensurePublished(ctx)
|
reporter.EnsurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
@@ -318,17 +327,17 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
enc, err := tokenizerForModel(modelForCounting)
|
enc, err := helps.TokenizerForModel(modelForCounting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := countOpenAIChatTokens(enc, translated)
|
count, err := helps.CountOpenAIChatTokens(enc, translated)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
usageJSON := buildOpenAIUsageJSON(count)
|
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ import (
|
|||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -23,20 +25,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
qwenUserAgent = "QwenCode/0.14.2 (darwin; arm64)"
|
||||||
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||||
qwenRateLimitWindow = time.Minute // sliding window duration
|
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||||
)
|
)
|
||||||
|
|
||||||
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
var qwenDefaultSystemMessage = []byte(`{"role":"system","content":[{"type":"text","text":"","cache_control":{"type":"ephemeral"}}]}`)
|
||||||
var qwenBeijingLoc = func() *time.Location {
|
|
||||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
|
||||||
if err != nil || loc == nil {
|
|
||||||
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
|
||||||
return time.FixedZone("CST", 8*3600)
|
|
||||||
}
|
|
||||||
return loc
|
|
||||||
}()
|
|
||||||
|
|
||||||
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
|
||||||
var qwenQuotaCodes = map[string]struct{}{
|
var qwenQuotaCodes = map[string]struct{}{
|
||||||
@@ -152,20 +146,110 @@ func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int,
|
|||||||
// Qwen returns 403 for quota errors, 429 for rate limits
|
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||||
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||||
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||||
cooldown := timeUntilNextDay()
|
// Do not force an excessively long retry-after (e.g. until tomorrow), otherwise
|
||||||
retryAfter = &cooldown
|
// the global request-retry scheduler may skip retries due to max-retry-interval.
|
||||||
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
helps.LogWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d)", httpCode, errCode)
|
||||||
}
|
}
|
||||||
return errCode, retryAfter
|
return errCode, retryAfter
|
||||||
}
|
}
|
||||||
|
|
||||||
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
// ensureQwenSystemMessage ensures the request has a single system message at the beginning.
|
||||||
// Qwen's daily quota resets at 00:00 Beijing time.
|
// It always injects the default system prompt and merges any user-provided system messages
|
||||||
func timeUntilNextDay() time.Duration {
|
// into the injected system message content to satisfy Qwen's strict message ordering rules.
|
||||||
now := time.Now()
|
func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
|
||||||
nowLocal := now.In(qwenBeijingLoc)
|
isInjectedSystemPart := func(part gjson.Result) bool {
|
||||||
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
if !part.Exists() || !part.IsObject() {
|
||||||
return tomorrow.Sub(now)
|
return false
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(part.Get("type").String(), "text") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
text := part.Get("text").String()
|
||||||
|
return text == "" || text == "You are Qwen Code."
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content")
|
||||||
|
var systemParts []any
|
||||||
|
if defaultParts.Exists() && defaultParts.IsArray() {
|
||||||
|
for _, part := range defaultParts.Array() {
|
||||||
|
systemParts = append(systemParts, part.Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(systemParts) == 0 {
|
||||||
|
systemParts = append(systemParts, map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "You are Qwen Code.",
|
||||||
|
"cache_control": map[string]any{
|
||||||
|
"type": "ephemeral",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
appendSystemContent := func(content gjson.Result) {
|
||||||
|
makeTextPart := func(text string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !content.Exists() || content.Type == gjson.Null {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Type == gjson.String {
|
||||||
|
systemParts = append(systemParts, makeTextPart(part.String()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isInjectedSystemPart(part) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
systemParts = append(systemParts, part.Value())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsObject() {
|
||||||
|
if isInjectedSystemPart(content) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
systemParts = append(systemParts, content.Value())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
var nonSystemMessages []any
|
||||||
|
if messages.Exists() && messages.IsArray() {
|
||||||
|
for _, msg := range messages.Array() {
|
||||||
|
if strings.EqualFold(msg.Get("role").String(), "system") {
|
||||||
|
appendSystemContent(msg.Get("content"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nonSystemMessages = append(nonSystemMessages, msg.Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newMessages := make([]any, 0, 1+len(nonSystemMessages))
|
||||||
|
newMessages = append(newMessages, map[string]any{
|
||||||
|
"role": "system",
|
||||||
|
"content": systemParts,
|
||||||
|
})
|
||||||
|
newMessages = append(newMessages, nonSystemMessages...)
|
||||||
|
|
||||||
|
updated, errSet := sjson.SetBytes(payload, "messages", newMessages)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet)
|
||||||
|
}
|
||||||
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||||
@@ -202,7 +286,7 @@ func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
return httpClient.Do(httpReq)
|
return httpClient.Do(httpReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,7 +301,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
}
|
}
|
||||||
if err := checkQwenRateLimit(authID); err != nil {
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -228,8 +312,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
baseURL = "https://portal.qwen.ai/v1"
|
baseURL = "https://portal.qwen.ai/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
@@ -247,8 +331,12 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, err = ensureQwenSystemMessage(body)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -256,12 +344,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, false)
|
applyQwenHeaders(httpReq, token, false)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
var authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -273,10 +366,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -284,23 +377,23 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
|
||||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
reporter.Publish(ctx, helps.ParseOpenAIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
@@ -320,7 +413,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
}
|
}
|
||||||
if err := checkQwenRateLimit(authID); err != nil {
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
helps.LogWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,8 +424,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
baseURL = "https://portal.qwen.ai/v1"
|
baseURL = "https://portal.qwen.ai/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
@@ -350,15 +443,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
toolsResult := gjson.GetBytes(body, "tools")
|
// toolsResult := gjson.GetBytes(body, "tools")
|
||||||
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
|
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
|
||||||
// This will have no real consequences. It's just to scare Qwen3.
|
// This will have no real consequences. It's just to scare Qwen3.
|
||||||
if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() {
|
// if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() {
|
||||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
// body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||||
}
|
// }
|
||||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, err = ensureQwenSystemMessage(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -366,12 +463,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, true)
|
applyQwenHeaders(httpReq, token, true)
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||||
var authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
@@ -383,19 +485,19 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
AuthValue: authValue,
|
AuthValue: authValue,
|
||||||
})
|
})
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
httpResp, err := httpClient.Do(httpReq)
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, err)
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
|
||||||
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -415,9 +517,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.Publish(ctx, detail)
|
||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
@@ -429,8 +531,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||||
}
|
}
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
reporter.publishFailure(ctx)
|
reporter.PublishFailure(ctx)
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -449,17 +551,17 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
modelName = baseModel
|
modelName = baseModel
|
||||||
}
|
}
|
||||||
|
|
||||||
enc, err := tokenizerForModel(modelName)
|
enc, err := helps.TokenizerForModel(modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := countOpenAIChatTokens(enc, body)
|
count, err := helps.CountOpenAIChatTokens(enc, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
usageJSON := buildOpenAIUsageJSON(count)
|
usageJSON := helps.BuildOpenAIUsageJSON(count)
|
||||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
}
|
}
|
||||||
@@ -505,20 +607,23 @@ func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
|||||||
}
|
}
|
||||||
|
|
||||||
func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||||
r.Header.Set("Content-Type", "application/json")
|
|
||||||
r.Header.Set("Authorization", "Bearer "+token)
|
|
||||||
r.Header.Set("User-Agent", qwenUserAgent)
|
|
||||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
|
||||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
r.Header.Set("User-Agent", qwenUserAgent)
|
||||||
r.Header.Set("X-Stainless-Lang", "js")
|
r.Header.Set("X-Stainless-Lang", "js")
|
||||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
r.Header.Set("Accept-Language", "*")
|
||||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
|
||||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
|
||||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||||
|
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||||
r.Header.Set("X-Stainless-Runtime", "node")
|
r.Header.Set("X-Stainless-Runtime", "node")
|
||||||
|
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||||
|
r.Header.Set("Accept-Encoding", "gzip, deflate")
|
||||||
|
r.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||||
|
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
r.Header.Set("Connection", "keep-alive")
|
||||||
|
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
@@ -527,6 +632,26 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
|||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normaliseQwenBaseURL(resourceURL string) string {
|
||||||
|
raw := strings.TrimSpace(resourceURL)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := raw
|
||||||
|
lower := strings.ToLower(normalized)
|
||||||
|
if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
|
||||||
|
normalized = "https://" + normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized = strings.TrimRight(normalized, "/")
|
||||||
|
if !strings.HasSuffix(strings.ToLower(normalized), "/v1") {
|
||||||
|
normalized += "/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
|
func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
@@ -544,7 +669,7 @@ func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
|
|||||||
token = v
|
token = v
|
||||||
}
|
}
|
||||||
if v, ok := a.Metadata["resource_url"].(string); ok {
|
if v, ok := a.Metadata["resource_url"].(string); ok {
|
||||||
baseURL = fmt.Sprintf("https://%s/v1", v)
|
baseURL = normaliseQwenBaseURL(v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQwenExecutorParseSuffix(t *testing.T) {
|
func TestQwenExecutorParseSuffix(t *testing.T) {
|
||||||
@@ -28,3 +32,180 @@ func TestQwenExecutorParseSuffix(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"model": "qwen3.6-plus",
|
||||||
|
"stream": true,
|
||||||
|
"messages": [
|
||||||
|
{ "role": "system", "content": "ABCDEFG" },
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Get("role").String() != "system" {
|
||||||
|
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||||
|
}
|
||||||
|
parts := msgs[0].Get("content").Array()
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||||
|
}
|
||||||
|
if parts[0].Get("type").String() != "text" || parts[0].Get("cache_control.type").String() != "ephemeral" {
|
||||||
|
t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw)
|
||||||
|
}
|
||||||
|
if text := parts[0].Get("text").String(); text != "" && text != "You are Qwen Code." {
|
||||||
|
t.Fatalf("messages[0].content[0].text = %q, want empty string or default prompt", text)
|
||||||
|
}
|
||||||
|
if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" {
|
||||||
|
t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw)
|
||||||
|
}
|
||||||
|
if msgs[1].Get("role").String() != "user" {
|
||||||
|
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "system", "content": { "type": "text", "text": "ABCDEFG" } },
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
parts := msgs[0].Get("content").Array()
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||||
|
}
|
||||||
|
if parts[1].Get("text").String() != "ABCDEFG" {
|
||||||
|
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Get("role").String() != "system" {
|
||||||
|
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||||
|
}
|
||||||
|
if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 {
|
||||||
|
t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw)
|
||||||
|
}
|
||||||
|
if msgs[1].Get("role").String() != "user" {
|
||||||
|
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "system", "content": "A" },
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "hi" } ] },
|
||||||
|
{ "role": "system", "content": "B" }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
parts := msgs[0].Get("content").Array()
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("messages[0].content length = %d, want 3", len(parts))
|
||||||
|
}
|
||||||
|
if parts[1].Get("text").String() != "A" {
|
||||||
|
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A")
|
||||||
|
}
|
||||||
|
if parts[2].Get("text").String() != "B" {
|
||||||
|
t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapQwenError_InsufficientQuotaDoesNotSetRetryAfter(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`)
|
||||||
|
code, retryAfter := wrapQwenError(context.Background(), http.StatusTooManyRequests, body)
|
||||||
|
if code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if retryAfter != nil {
|
||||||
|
t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapQwenError_Maps403QuotaTo429WithoutRetryAfter(t *testing.T) {
|
||||||
|
body := []byte(`{"error":{"code":"insufficient_quota","message":"You exceeded your current quota","type":"insufficient_quota"}}`)
|
||||||
|
code, retryAfter := wrapQwenError(context.Background(), http.StatusForbidden, body)
|
||||||
|
if code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("wrapQwenError status = %d, want %d", code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if retryAfter != nil {
|
||||||
|
t.Fatalf("wrapQwenError retryAfter = %v, want nil", *retryAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwenCreds_NormalizesResourceURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resourceURL string
|
||||||
|
wantBaseURL string
|
||||||
|
}{
|
||||||
|
{"host only", "portal.qwen.ai", "https://portal.qwen.ai/v1"},
|
||||||
|
{"scheme no v1", "https://portal.qwen.ai", "https://portal.qwen.ai/v1"},
|
||||||
|
{"scheme with v1", "https://portal.qwen.ai/v1", "https://portal.qwen.ai/v1"},
|
||||||
|
{"scheme with v1 slash", "https://portal.qwen.ai/v1/", "https://portal.qwen.ai/v1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "test-token",
|
||||||
|
"resource_url": tt.resourceURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token, baseURL := qwenCreds(auth)
|
||||||
|
if token != "test-token" {
|
||||||
|
t.Fatalf("qwenCreds token = %q, want %q", token, "test-token")
|
||||||
|
}
|
||||||
|
if baseURL != tt.wantBaseURL {
|
||||||
|
t.Fatalf("qwenCreds baseURL = %q, want %q", baseURL, tt.wantBaseURL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,16 +32,24 @@ type GitTokenStore struct {
|
|||||||
repoDir string
|
repoDir string
|
||||||
configDir string
|
configDir string
|
||||||
remote string
|
remote string
|
||||||
|
branch string
|
||||||
username string
|
username string
|
||||||
password string
|
password string
|
||||||
lastGC time.Time
|
lastGC time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type resolvedRemoteBranch struct {
|
||||||
|
name plumbing.ReferenceName
|
||||||
|
hash plumbing.Hash
|
||||||
|
}
|
||||||
|
|
||||||
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
||||||
// TokenStorage implementation embedded in the token record.
|
// TokenStorage implementation embedded in the token record.
|
||||||
func NewGitTokenStore(remote, username, password string) *GitTokenStore {
|
// When branch is non-empty, clone/pull/push operations target that branch instead of the remote default.
|
||||||
|
func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore {
|
||||||
return &GitTokenStore{
|
return &GitTokenStore{
|
||||||
remote: remote,
|
remote: remote,
|
||||||
|
branch: strings.TrimSpace(branch),
|
||||||
username: username,
|
username: username,
|
||||||
password: password,
|
password: password,
|
||||||
}
|
}
|
||||||
@@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error {
|
|||||||
s.dirLock.Unlock()
|
s.dirLock.Unlock()
|
||||||
return fmt.Errorf("git token store: create repo dir: %w", errMk)
|
return fmt.Errorf("git token store: create repo dir: %w", errMk)
|
||||||
}
|
}
|
||||||
if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil {
|
cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote}
|
||||||
|
if s.branch != "" {
|
||||||
|
cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
|
||||||
|
}
|
||||||
|
if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil {
|
||||||
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
|
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
|
||||||
_ = os.RemoveAll(gitDir)
|
_ = os.RemoveAll(gitDir)
|
||||||
repo, errInit := git.PlainInit(repoDir, false)
|
repo, errInit := git.PlainInit(repoDir, false)
|
||||||
@@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error {
|
|||||||
s.dirLock.Unlock()
|
s.dirLock.Unlock()
|
||||||
return fmt.Errorf("git token store: init empty repo: %w", errInit)
|
return fmt.Errorf("git token store: init empty repo: %w", errInit)
|
||||||
}
|
}
|
||||||
|
if s.branch != "" {
|
||||||
|
headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch))
|
||||||
|
if errHead := repo.Storer.SetReference(headRef); errHead != nil {
|
||||||
|
s.dirLock.Unlock()
|
||||||
|
return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead)
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, errRemote := repo.Remote("origin"); errRemote != nil {
|
if _, errRemote := repo.Remote("origin"); errRemote != nil {
|
||||||
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
|
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
|
||||||
Name: "origin",
|
Name: "origin",
|
||||||
@@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error {
|
|||||||
s.dirLock.Unlock()
|
s.dirLock.Unlock()
|
||||||
return fmt.Errorf("git token store: worktree: %w", errWorktree)
|
return fmt.Errorf("git token store: worktree: %w", errWorktree)
|
||||||
}
|
}
|
||||||
if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil {
|
if s.branch != "" {
|
||||||
|
if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil {
|
||||||
|
s.dirLock.Unlock()
|
||||||
|
return errCheckout
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// When branch is unset, ensure the working tree follows the remote default branch
|
||||||
|
if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil {
|
||||||
|
if !shouldFallbackToCurrentBranch(repo, err) {
|
||||||
|
s.dirLock.Unlock()
|
||||||
|
return fmt.Errorf("git token store: checkout remote default: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"}
|
||||||
|
if s.branch != "" {
|
||||||
|
pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
|
||||||
|
}
|
||||||
|
if errPull := worktree.Pull(pullOpts); errPull != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(errPull, git.NoErrAlreadyUpToDate),
|
case errors.Is(errPull, git.NoErrAlreadyUpToDate),
|
||||||
errors.Is(errPull, git.ErrUnstagedChanges),
|
errors.Is(errPull, git.ErrUnstagedChanges),
|
||||||
errors.Is(errPull, git.ErrNonFastForwardUpdate):
|
errors.Is(errPull, git.ErrNonFastForwardUpdate):
|
||||||
// Ignore clean syncs, local edits, and remote divergence—local changes win.
|
// Ignore clean syncs, local edits, and remote divergence—local changes win.
|
||||||
case errors.Is(errPull, transport.ErrAuthenticationRequired),
|
case errors.Is(errPull, transport.ErrAuthenticationRequired),
|
||||||
errors.Is(errPull, plumbing.ErrReferenceNotFound),
|
|
||||||
errors.Is(errPull, transport.ErrEmptyRemoteRepository):
|
errors.Is(errPull, transport.ErrEmptyRemoteRepository):
|
||||||
// Ignore authentication prompts and empty remote references on initial sync.
|
// Ignore authentication prompts and empty remote references on initial sync.
|
||||||
|
case errors.Is(errPull, plumbing.ErrReferenceNotFound):
|
||||||
|
if s.branch != "" {
|
||||||
|
s.dirLock.Unlock()
|
||||||
|
return fmt.Errorf("git token store: pull: %w", errPull)
|
||||||
|
}
|
||||||
|
// Ignore missing references only when following the remote default branch.
|
||||||
default:
|
default:
|
||||||
s.dirLock.Unlock()
|
s.dirLock.Unlock()
|
||||||
return fmt.Errorf("git token store: pull: %w", errPull)
|
return fmt.Errorf("git token store: pull: %w", errPull)
|
||||||
@@ -446,6 +488,7 @@ func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
|||||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||||
auth.Attributes["email"] = email
|
auth.Attributes["email"] = email
|
||||||
}
|
}
|
||||||
|
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,6 +596,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
|
|||||||
return rel, nil
|
return rel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
|
||||||
|
branchRefName := plumbing.NewBranchReferenceName(s.branch)
|
||||||
|
headRef, errHead := repo.Head()
|
||||||
|
switch {
|
||||||
|
case errHead == nil && headRef.Name() == branchRefName:
|
||||||
|
return nil
|
||||||
|
case errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound):
|
||||||
|
return fmt.Errorf("git token store: get head: %w", errHead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else if _, errRef := repo.Reference(branchRefName, true); errRef == nil {
|
||||||
|
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
|
||||||
|
} else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) {
|
||||||
|
return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef)
|
||||||
|
} else if err := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, branchRefName, authMethod); err != nil {
|
||||||
|
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error {
|
||||||
|
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch)
|
||||||
|
remoteRef, err := repo.Reference(remoteRefName, true)
|
||||||
|
if errors.Is(err, plumbing.ErrReferenceNotFound) {
|
||||||
|
if errSync := syncRemoteReferences(repo, authMethod); errSync != nil {
|
||||||
|
return fmt.Errorf("sync remote refs: %w", errSync)
|
||||||
|
}
|
||||||
|
remoteRef, err = repo.Reference(remoteRefName, true)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := repo.Config()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("git token store: repo config: %w", err)
|
||||||
|
}
|
||||||
|
if _, ok := cfg.Branches[s.branch]; !ok {
|
||||||
|
cfg.Branches[s.branch] = &config.Branch{Name: s.branch}
|
||||||
|
}
|
||||||
|
cfg.Branches[s.branch].Remote = "origin"
|
||||||
|
cfg.Branches[s.branch].Merge = branchRefName
|
||||||
|
if err := repo.SetConfig(cfg); err != nil {
|
||||||
|
return fmt.Errorf("git token store: set branch config: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error {
|
||||||
|
if err := repo.Fetch(&git.FetchOptions{Auth: authMethod, RemoteName: "origin"}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch
|
||||||
|
// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master).
|
||||||
|
func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) {
|
||||||
|
if err := syncRemoteReferences(repo, authMethod); err != nil {
|
||||||
|
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err)
|
||||||
|
}
|
||||||
|
remote, err := repo.Remote("origin")
|
||||||
|
if err != nil {
|
||||||
|
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err)
|
||||||
|
}
|
||||||
|
refs, err := remote.List(&git.ListOptions{Auth: authMethod})
|
||||||
|
if err != nil {
|
||||||
|
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
|
||||||
|
return resolved, nil
|
||||||
|
}
|
||||||
|
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err)
|
||||||
|
}
|
||||||
|
for _, r := range refs {
|
||||||
|
if r.Name() == plumbing.HEAD {
|
||||||
|
if r.Type() == plumbing.SymbolicReference {
|
||||||
|
if target, ok := normalizeRemoteBranchReference(r.Target()); ok {
|
||||||
|
return resolvedRemoteBranch{name: target}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s := r.String()
|
||||||
|
if idx := strings.Index(s, "->"); idx != -1 {
|
||||||
|
if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok {
|
||||||
|
return resolvedRemoteBranch{name: target}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
|
||||||
|
return resolved, nil
|
||||||
|
}
|
||||||
|
for _, r := range refs {
|
||||||
|
if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok {
|
||||||
|
return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) {
|
||||||
|
ref, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true)
|
||||||
|
if err != nil || ref.Type() != plumbing.SymbolicReference {
|
||||||
|
return resolvedRemoteBranch{}, false
|
||||||
|
}
|
||||||
|
target, ok := normalizeRemoteBranchReference(ref.Target())
|
||||||
|
if !ok {
|
||||||
|
return resolvedRemoteBranch{}, false
|
||||||
|
}
|
||||||
|
return resolvedRemoteBranch{name: target}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(name.String(), "refs/heads/"):
|
||||||
|
return name, true
|
||||||
|
case strings.HasPrefix(name.String(), "refs/remotes/origin/"):
|
||||||
|
return plumbing.NewBranchReferenceName(strings.TrimPrefix(name.String(), "refs/remotes/origin/")), true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool {
|
||||||
|
if !errors.Is(err, transport.ErrAuthenticationRequired) && !errors.Is(err, transport.ErrEmptyRemoteRepository) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, headErr := repo.Head()
|
||||||
|
return headErr == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's default branch
|
||||||
|
// (the branch target of origin/HEAD). If the local branch does not exist it will be created to track
|
||||||
|
// the remote branch.
|
||||||
|
func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
|
||||||
|
resolved, err := resolveRemoteDefaultBranch(repo, authMethod)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
branchRefName := resolved.name
|
||||||
|
// If HEAD already points to the desired branch, nothing to do.
|
||||||
|
headRef, errHead := repo.Head()
|
||||||
|
if errHead == nil && headRef.Name() == branchRefName {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// If local branch exists, attempt a checkout
|
||||||
|
if _, err := repo.Reference(branchRefName, true); err == nil {
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err != nil {
|
||||||
|
return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Try to find the corresponding remote tracking ref (refs/remotes/origin/<name>)
|
||||||
|
branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/")
|
||||||
|
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort)
|
||||||
|
hash := resolved.hash
|
||||||
|
if remoteRef, err := repo.Reference(remoteRefName, true); err == nil {
|
||||||
|
hash = remoteRef.Hash()
|
||||||
|
} else if err != nil && !errors.Is(err, plumbing.ErrReferenceNotFound) {
|
||||||
|
return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), err)
|
||||||
|
}
|
||||||
|
if hash == plumbing.ZeroHash {
|
||||||
|
return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String())
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil {
|
||||||
|
return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err)
|
||||||
|
}
|
||||||
|
cfg, err := repo.Config()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("git token store: repo config: %w", err)
|
||||||
|
}
|
||||||
|
if _, ok := cfg.Branches[branchShort]; !ok {
|
||||||
|
cfg.Branches[branchShort] = &config.Branch{Name: branchShort}
|
||||||
|
}
|
||||||
|
cfg.Branches[branchShort].Remote = "origin"
|
||||||
|
cfg.Branches[branchShort].Merge = branchRefName
|
||||||
|
if err := repo.SetConfig(cfg); err != nil {
|
||||||
|
return fmt.Errorf("git token store: set branch config: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
|
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
|
||||||
repoDir := s.repoDirSnapshot()
|
repoDir := s.repoDirSnapshot()
|
||||||
if repoDir == "" {
|
if repoDir == "" {
|
||||||
@@ -618,7 +847,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
|
|||||||
return errRewrite
|
return errRewrite
|
||||||
}
|
}
|
||||||
s.maybeRunGC(repo)
|
s.maybeRunGC(repo)
|
||||||
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true}
|
||||||
|
if s.branch != "" {
|
||||||
|
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)}
|
||||||
|
} else {
|
||||||
|
// When branch is unset, pin push to the currently checked-out branch.
|
||||||
|
if headRef, err := repo.Head(); err == nil {
|
||||||
|
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = repo.Push(pushOpts); err != nil {
|
||||||
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
585
internal/store/gitstore_test.go
Normal file
585
internal/store/gitstore_test.go
Normal file
@@ -0,0 +1,585 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-git/go-git/v6"
|
||||||
|
gitconfig "github.com/go-git/go-git/v6/config"
|
||||||
|
"github.com/go-git/go-git/v6/plumbing"
|
||||||
|
"github.com/go-git/go-git/v6/plumbing/object"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testBranchSpec struct {
|
||||||
|
name string
|
||||||
|
contents string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||||
|
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||||
|
testBranchSpec{name: "release/2026", contents: "release branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch\n")
|
||||||
|
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
|
||||||
|
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository second call: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch updated\n")
|
||||||
|
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||||
|
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||||
|
testBranchSpec{name: "release/2026", contents: "release branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "release/2026")
|
||||||
|
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
|
||||||
|
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
|
||||||
|
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository second call: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch updated\n")
|
||||||
|
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||||
|
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "missing-branch")
|
||||||
|
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||||
|
|
||||||
|
err := store.EnsureRepository()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("EnsureRepository succeeded, want error for nonexistent configured branch")
|
||||||
|
}
|
||||||
|
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "trunk",
|
||||||
|
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
baseDir := filepath.Join(root, "workspace", "auths")
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
store.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reopened := NewGitTokenStore(remoteDir, "", "", "missing-branch")
|
||||||
|
reopened.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
err := reopened.EnsureRepository()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("EnsureRepository succeeded on reopen, want error for nonexistent configured branch")
|
||||||
|
}
|
||||||
|
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "trunk")
|
||||||
|
assertRemoteHeadBranch(t, remoteDir, "trunk")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := filepath.Join(root, "remote.git")
|
||||||
|
if _, err := git.PlainInit(remoteDir, true); err != nil {
|
||||||
|
t.Fatalf("init bare remote: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
branch := "feature/gemini-fix"
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", branch)
|
||||||
|
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), branch)
|
||||||
|
assertRemoteBranchExistsWithCommit(t, remoteDir, branch)
|
||||||
|
assertRemoteBranchDoesNotExist(t, remoteDir, "master")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||||
|
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||||
|
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
baseDir := filepath.Join(root, "workspace", "auths")
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
store.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
|
||||||
|
|
||||||
|
reopened := NewGitTokenStore(remoteDir, "", "", "develop")
|
||||||
|
reopened.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := reopened.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository reopen: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
|
||||||
|
|
||||||
|
workspaceDir := filepath.Join(root, "workspace")
|
||||||
|
if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local develop update\n"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write local branch marker: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reopened.mu.Lock()
|
||||||
|
err := reopened.commitAndPushLocked("Update develop branch marker", "branch.txt")
|
||||||
|
reopened.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("commitAndPushLocked: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertRepositoryHeadBranch(t, workspaceDir, "develop")
|
||||||
|
assertRemoteBranchContents(t, remoteDir, "develop", "local develop update\n")
|
||||||
|
assertRemoteBranchContents(t, remoteDir, "master", "remote master branch\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||||
|
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
baseDir := filepath.Join(root, "workspace", "auths")
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
store.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
|
||||||
|
|
||||||
|
advanceRemoteBranchFromNewBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch\n", "create release")
|
||||||
|
|
||||||
|
reopened := NewGitTokenStore(remoteDir, "", "", "release/2026")
|
||||||
|
reopened.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := reopened.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository reopen: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||||
|
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||||
|
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
baseDir := filepath.Join(root, "workspace", "auths")
|
||||||
|
// First store pins to develop and prepares local workspace
|
||||||
|
storePinned := NewGitTokenStore(remoteDir, "", "", "develop")
|
||||||
|
storePinned.SetBaseDir(baseDir)
|
||||||
|
if err := storePinned.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository pinned: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
|
||||||
|
|
||||||
|
// Second store has branch unset and should reset local workspace to remote default (master)
|
||||||
|
storeDefault := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
storeDefault.SetBaseDir(baseDir)
|
||||||
|
if err := storeDefault.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository default: %v", err)
|
||||||
|
}
|
||||||
|
// Local HEAD should now follow remote default (master)
|
||||||
|
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "master")
|
||||||
|
|
||||||
|
// Make a local change and push using the store with branch unset; push should update remote master
|
||||||
|
workspaceDir := filepath.Join(root, "workspace")
|
||||||
|
if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local master update\n"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write local master marker: %v", err)
|
||||||
|
}
|
||||||
|
storeDefault.mu.Lock()
|
||||||
|
if err := storeDefault.commitAndPushLocked("Update master marker", "branch.txt"); err != nil {
|
||||||
|
storeDefault.mu.Unlock()
|
||||||
|
t.Fatalf("commitAndPushLocked: %v", err)
|
||||||
|
}
|
||||||
|
storeDefault.mu.Unlock()
|
||||||
|
|
||||||
|
assertRemoteBranchContents(t, remoteDir, "master", "local master update\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||||
|
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||||
|
testBranchSpec{name: "main", contents: "remote main branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
baseDir := filepath.Join(root, "workspace", "auths")
|
||||||
|
store := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
store.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := store.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository initial clone: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
|
||||||
|
|
||||||
|
setRemoteHeadBranch(t, remoteDir, "main")
|
||||||
|
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "main", "remote main branch updated\n", "advance main")
|
||||||
|
|
||||||
|
reopened := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
reopened.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := reopened.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository after remote default rename: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "main", "remote main branch updated\n")
|
||||||
|
assertRemoteHeadBranch(t, remoteDir, "main")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
remoteDir := setupGitRemoteRepository(t, root, "master",
|
||||||
|
testBranchSpec{name: "master", contents: "remote master branch\n"},
|
||||||
|
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
|
||||||
|
)
|
||||||
|
|
||||||
|
baseDir := filepath.Join(root, "workspace", "auths")
|
||||||
|
pinned := NewGitTokenStore(remoteDir, "", "", "develop")
|
||||||
|
pinned.SetBaseDir(baseDir)
|
||||||
|
if err := pinned.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository pinned: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
|
||||||
|
|
||||||
|
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="git"`)
|
||||||
|
http.Error(w, "auth required", http.StatusUnauthorized)
|
||||||
|
}))
|
||||||
|
defer authServer.Close()
|
||||||
|
|
||||||
|
repo, err := git.PlainOpen(filepath.Join(root, "workspace"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open workspace repo: %v", err)
|
||||||
|
}
|
||||||
|
cfg, err := repo.Config()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read repo config: %v", err)
|
||||||
|
}
|
||||||
|
cfg.Remotes["origin"].URLs = []string{authServer.URL}
|
||||||
|
if err := repo.SetConfig(cfg); err != nil {
|
||||||
|
t.Fatalf("set repo config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reopened := NewGitTokenStore(remoteDir, "", "", "")
|
||||||
|
reopened.SetBaseDir(baseDir)
|
||||||
|
|
||||||
|
if err := reopened.EnsureRepository(); err != nil {
|
||||||
|
t.Fatalf("EnsureRepository default branch fallback: %v", err)
|
||||||
|
}
|
||||||
|
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "develop")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupGitRemoteRepository(t *testing.T, root, defaultBranch string, branches ...testBranchSpec) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
remoteDir := filepath.Join(root, "remote.git")
|
||||||
|
if _, err := git.PlainInit(remoteDir, true); err != nil {
|
||||||
|
t.Fatalf("init bare remote: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedDir := filepath.Join(root, "seed")
|
||||||
|
seedRepo, err := git.PlainInit(seedDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("init seed repo: %v", err)
|
||||||
|
}
|
||||||
|
if err := seedRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
|
||||||
|
t.Fatalf("set seed HEAD: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
worktree, err := seedRepo.Worktree()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open seed worktree: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSpec, ok := findBranchSpec(branches, defaultBranch)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("missing default branch spec for %q", defaultBranch)
|
||||||
|
}
|
||||||
|
commitBranchMarker(t, seedDir, worktree, defaultSpec, "seed default branch")
|
||||||
|
|
||||||
|
for _, branch := range branches {
|
||||||
|
if branch.name == defaultBranch {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(defaultBranch)}); err != nil {
|
||||||
|
t.Fatalf("checkout default branch %s: %v", defaultBranch, err)
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch.name), Create: true}); err != nil {
|
||||||
|
t.Fatalf("create branch %s: %v", branch.name, err)
|
||||||
|
}
|
||||||
|
commitBranchMarker(t, seedDir, worktree, branch, "seed branch "+branch.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := seedRepo.CreateRemote(&gitconfig.RemoteConfig{Name: "origin", URLs: []string{remoteDir}}); err != nil {
|
||||||
|
t.Fatalf("create origin remote: %v", err)
|
||||||
|
}
|
||||||
|
if err := seedRepo.Push(&git.PushOptions{
|
||||||
|
RemoteName: "origin",
|
||||||
|
RefSpecs: []gitconfig.RefSpec{gitconfig.RefSpec("refs/heads/*:refs/heads/*")},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("push seed branches: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open remote repo: %v", err)
|
||||||
|
}
|
||||||
|
if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
|
||||||
|
t.Fatalf("set remote HEAD: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return remoteDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func commitBranchMarker(t *testing.T, seedDir string, worktree *git.Worktree, branch testBranchSpec, message string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err := os.WriteFile(filepath.Join(seedDir, "branch.txt"), []byte(branch.contents), 0o600); err != nil {
|
||||||
|
t.Fatalf("write branch marker for %s: %v", branch.name, err)
|
||||||
|
}
|
||||||
|
if _, err := worktree.Add("branch.txt"); err != nil {
|
||||||
|
t.Fatalf("add branch marker for %s: %v", branch.name, err)
|
||||||
|
}
|
||||||
|
if _, err := worktree.Commit(message, &git.CommitOptions{
|
||||||
|
Author: &object.Signature{
|
||||||
|
Name: "CLIProxyAPI",
|
||||||
|
Email: "cliproxy@local",
|
||||||
|
When: time.Unix(1711929600, 0),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("commit branch marker for %s: %v", branch.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func advanceRemoteBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
seedRepo, err := git.PlainOpen(seedDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open seed repo: %v", err)
|
||||||
|
}
|
||||||
|
worktree, err := seedRepo.Worktree()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open seed worktree: %v", err)
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch)}); err != nil {
|
||||||
|
t.Fatalf("checkout branch %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
|
||||||
|
if err := seedRepo.Push(&git.PushOptions{
|
||||||
|
RemoteName: "origin",
|
||||||
|
RefSpecs: []gitconfig.RefSpec{
|
||||||
|
gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("push branch %s update to %s: %v", branch, remoteDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func advanceRemoteBranchFromNewBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
seedRepo, err := git.PlainOpen(seedDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open seed repo: %v", err)
|
||||||
|
}
|
||||||
|
worktree, err := seedRepo.Worktree()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open seed worktree: %v", err)
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName("master")}); err != nil {
|
||||||
|
t.Fatalf("checkout master before creating %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch), Create: true}); err != nil {
|
||||||
|
t.Fatalf("create branch %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
|
||||||
|
if err := seedRepo.Push(&git.PushOptions{
|
||||||
|
RemoteName: "origin",
|
||||||
|
RefSpecs: []gitconfig.RefSpec{
|
||||||
|
gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("push new branch %s update to %s: %v", branch, remoteDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findBranchSpec(branches []testBranchSpec, name string) (testBranchSpec, bool) {
|
||||||
|
for _, branch := range branches {
|
||||||
|
if branch.name == name {
|
||||||
|
return branch, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return testBranchSpec{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRepositoryBranchAndContents(t *testing.T, repoDir, branch, wantContents string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
repo, err := git.PlainOpen(repoDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open local repo: %v", err)
|
||||||
|
}
|
||||||
|
head, err := repo.Head()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("local repo head: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
|
||||||
|
t.Fatalf("local head branch = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
contents, err := os.ReadFile(filepath.Join(repoDir, "branch.txt"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read branch marker: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(contents); got != wantContents {
|
||||||
|
t.Fatalf("branch marker contents = %q, want %q", got, wantContents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRepositoryHeadBranch(t *testing.T, repoDir, branch string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
repo, err := git.PlainOpen(repoDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open local repo: %v", err)
|
||||||
|
}
|
||||||
|
head, err := repo.Head()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("local repo head: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
|
||||||
|
t.Fatalf("local head branch = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open remote repo: %v", err)
|
||||||
|
}
|
||||||
|
head, err := remoteRepo.Reference(plumbing.HEAD, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote HEAD: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := head.Target(), plumbing.NewBranchReferenceName(branch); got != want {
|
||||||
|
t.Fatalf("remote HEAD target = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open remote repo: %v", err)
|
||||||
|
}
|
||||||
|
if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(branch))); err != nil {
|
||||||
|
t.Fatalf("set remote HEAD to %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRemoteBranchExistsWithCommit(t *testing.T, remoteDir, branch string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open remote repo: %v", err)
|
||||||
|
}
|
||||||
|
ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote branch %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
if got := ref.Hash(); got == plumbing.ZeroHash {
|
||||||
|
t.Fatalf("remote branch %s hash = %s, want non-zero hash", branch, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRemoteBranchDoesNotExist(t *testing.T, remoteDir, branch string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open remote repo: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false); err == nil {
|
||||||
|
t.Fatalf("remote branch %s exists, want missing", branch)
|
||||||
|
} else if err != plumbing.ErrReferenceNotFound {
|
||||||
|
t.Fatalf("read remote branch %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRemoteBranchContents(t *testing.T, remoteDir, branch, wantContents string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
remoteRepo, err := git.PlainOpen(remoteDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open remote repo: %v", err)
|
||||||
|
}
|
||||||
|
ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote branch %s: %v", branch, err)
|
||||||
|
}
|
||||||
|
commit, err := remoteRepo.CommitObject(ref.Hash())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote branch %s commit: %v", branch, err)
|
||||||
|
}
|
||||||
|
tree, err := commit.Tree()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote branch %s tree: %v", branch, err)
|
||||||
|
}
|
||||||
|
file, err := tree.File("branch.txt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote branch %s file: %v", branch, err)
|
||||||
|
}
|
||||||
|
contents, err := file.Contents()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read remote branch %s contents: %v", branch, err)
|
||||||
|
}
|
||||||
|
if contents != wantContents {
|
||||||
|
t.Fatalf("remote branch %s contents = %q, want %q", branch, contents, wantContents)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -595,6 +595,7 @@ func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Aut
|
|||||||
LastRefreshedAt: time.Time{},
|
LastRefreshedAt: time.Time{},
|
||||||
NextRefreshAfter: time.Time{},
|
NextRefreshAfter: time.Time{},
|
||||||
}
|
}
|
||||||
|
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -310,6 +310,7 @@ func (s *PostgresStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error)
|
|||||||
LastRefreshedAt: time.Time{},
|
LastRefreshedAt: time.Time{},
|
||||||
NextRefreshAfter: time.Time{},
|
NextRefreshAfter: time.Time{},
|
||||||
}
|
}
|
||||||
|
cliproxyauth.ApplyCustomHeadersFromMetadata(auth)
|
||||||
auths = append(auths, auth)
|
auths = append(auths, auth)
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
|
|||||||
@@ -330,32 +330,45 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reorder parts for 'model' role to ensure thinking block is first
|
// Reorder parts for 'model' role:
|
||||||
|
// 1. Thinking parts first (Antigravity API requirement)
|
||||||
|
// 2. Regular parts (text, inlineData, etc.)
|
||||||
|
// 3. FunctionCall parts last
|
||||||
|
//
|
||||||
|
// Moving functionCall parts to the end prevents tool_use↔tool_result
|
||||||
|
// pairing breakage: the Antigravity API internally splits model messages
|
||||||
|
// at functionCall boundaries. If a text part follows a functionCall, the
|
||||||
|
// split creates an extra assistant turn between tool_use and tool_result,
|
||||||
|
// which Claude rejects with "tool_use ids were found without tool_result
|
||||||
|
// blocks immediately after".
|
||||||
if role == "model" {
|
if role == "model" {
|
||||||
partsResult := gjson.GetBytes(clientContentJSON, "parts")
|
partsResult := gjson.GetBytes(clientContentJSON, "parts")
|
||||||
if partsResult.IsArray() {
|
if partsResult.IsArray() {
|
||||||
parts := partsResult.Array()
|
parts := partsResult.Array()
|
||||||
var thinkingParts []gjson.Result
|
if len(parts) > 1 {
|
||||||
var otherParts []gjson.Result
|
var thinkingParts []gjson.Result
|
||||||
for _, part := range parts {
|
var regularParts []gjson.Result
|
||||||
if part.Get("thought").Bool() {
|
var functionCallParts []gjson.Result
|
||||||
thinkingParts = append(thinkingParts, part)
|
for _, part := range parts {
|
||||||
} else {
|
if part.Get("thought").Bool() {
|
||||||
otherParts = append(otherParts, part)
|
thinkingParts = append(thinkingParts, part)
|
||||||
}
|
} else if part.Get("functionCall").Exists() {
|
||||||
}
|
functionCallParts = append(functionCallParts, part)
|
||||||
if len(thinkingParts) > 0 {
|
} else {
|
||||||
firstPartIsThinking := parts[0].Get("thought").Bool()
|
regularParts = append(regularParts, part)
|
||||||
if !firstPartIsThinking || len(thinkingParts) > 1 {
|
|
||||||
var newParts []interface{}
|
|
||||||
for _, p := range thinkingParts {
|
|
||||||
newParts = append(newParts, p.Value())
|
|
||||||
}
|
}
|
||||||
for _, p := range otherParts {
|
|
||||||
newParts = append(newParts, p.Value())
|
|
||||||
}
|
|
||||||
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
|
||||||
}
|
}
|
||||||
|
var newParts []interface{}
|
||||||
|
for _, p := range thinkingParts {
|
||||||
|
newParts = append(newParts, p.Value())
|
||||||
|
}
|
||||||
|
for _, p := range regularParts {
|
||||||
|
newParts = append(newParts, p.Value())
|
||||||
|
}
|
||||||
|
for _, p := range functionCallParts {
|
||||||
|
newParts = append(newParts, p.Value())
|
||||||
|
}
|
||||||
|
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -361,6 +361,167 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ReorderTextAfterFunctionCall(t *testing.T) {
|
||||||
|
// Bug: text part after tool_use in an assistant message causes Antigravity
|
||||||
|
// to split at functionCall boundary, creating an extra assistant turn that
|
||||||
|
// breaks tool_use↔tool_result adjacency (upstream issue #989).
|
||||||
|
// Fix: reorder parts so functionCall comes last.
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Let me check..."},
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "call_abc",
|
||||||
|
"name": "Read",
|
||||||
|
"input": {"file": "test.go"}
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Reading the file now"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "call_abc",
|
||||||
|
"content": "file content"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("Expected 3 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text parts should come before functionCall
|
||||||
|
if parts[0].Get("text").String() != "Let me check..." {
|
||||||
|
t.Errorf("Expected first text part first, got %s", parts[0].Raw)
|
||||||
|
}
|
||||||
|
if parts[1].Get("text").String() != "Reading the file now" {
|
||||||
|
t.Errorf("Expected second text part second, got %s", parts[1].Raw)
|
||||||
|
}
|
||||||
|
if !parts[2].Get("functionCall").Exists() {
|
||||||
|
t.Errorf("Expected functionCall last, got %s", parts[2].Raw)
|
||||||
|
}
|
||||||
|
if parts[2].Get("functionCall.name").String() != "Read" {
|
||||||
|
t.Errorf("Expected functionCall name 'Read', got '%s'", parts[2].Get("functionCall.name").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ReorderParallelFunctionCalls(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Reading both files."},
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "call_1",
|
||||||
|
"name": "Read",
|
||||||
|
"input": {"file": "a.go"}
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "And this one too."},
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "call_2",
|
||||||
|
"name": "Read",
|
||||||
|
"input": {"file": "b.go"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||||
|
if len(parts) != 4 {
|
||||||
|
t.Fatalf("Expected 4 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
if parts[0].Get("text").String() != "Reading both files." {
|
||||||
|
t.Errorf("Expected first text, got %s", parts[0].Raw)
|
||||||
|
}
|
||||||
|
if parts[1].Get("text").String() != "And this one too." {
|
||||||
|
t.Errorf("Expected second text, got %s", parts[1].Raw)
|
||||||
|
}
|
||||||
|
if parts[2].Get("functionCall.name").String() != "Read" || parts[2].Get("functionCall.id").String() != "call_1" {
|
||||||
|
t.Errorf("Expected fc1 third, got %s", parts[2].Raw)
|
||||||
|
}
|
||||||
|
if parts[3].Get("functionCall.name").String() != "Read" || parts[3].Get("functionCall.id").String() != "call_2" {
|
||||||
|
t.Errorf("Expected fc2 fourth, got %s", parts[3].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ReorderThinkingAndTextBeforeFunctionCall(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think about this..."
|
||||||
|
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Hello"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Before thinking"},
|
||||||
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "call_xyz",
|
||||||
|
"name": "Bash",
|
||||||
|
"input": {"command": "ls"}
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "After tool call"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
// contents.1 = assistant message (contents.0 = user)
|
||||||
|
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
|
||||||
|
if len(parts) != 4 {
|
||||||
|
t.Fatalf("Expected 4 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order: thinking → text → text → functionCall
|
||||||
|
if !parts[0].Get("thought").Bool() {
|
||||||
|
t.Error("First part should be thinking")
|
||||||
|
}
|
||||||
|
if parts[1].Get("functionCall").Exists() || parts[1].Get("thought").Bool() {
|
||||||
|
t.Errorf("Second part should be text, got %s", parts[1].Raw)
|
||||||
|
}
|
||||||
|
if parts[2].Get("functionCall").Exists() || parts[2].Get("thought").Bool() {
|
||||||
|
t.Errorf("Third part should be text, got %s", parts[2].Raw)
|
||||||
|
}
|
||||||
|
if !parts[3].Get("functionCall").Exists() {
|
||||||
|
t.Errorf("Last part should be functionCall, got %s", parts[3].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-3-5-sonnet-20240620",
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
|||||||
@@ -26,6 +26,11 @@ type ConvertCodexResponseToClaudeParams struct {
|
|||||||
HasToolCall bool
|
HasToolCall bool
|
||||||
BlockIndex int
|
BlockIndex int
|
||||||
HasReceivedArgumentsDelta bool
|
HasReceivedArgumentsDelta bool
|
||||||
|
HasTextDelta bool
|
||||||
|
TextBlockOpen bool
|
||||||
|
ThinkingBlockOpen bool
|
||||||
|
ThinkingStopPending bool
|
||||||
|
ThinkingSignature string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||||
@@ -44,7 +49,7 @@ type ConvertCodexResponseToClaudeParams struct {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - [][]byte: A slice of Claude Code-compatible JSON responses
|
// - [][]byte: A slice of Claude Code-compatible JSON responses
|
||||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, param *any) [][]byte {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &ConvertCodexResponseToClaudeParams{
|
*param = &ConvertCodexResponseToClaudeParams{
|
||||||
HasToolCall: false,
|
HasToolCall: false,
|
||||||
@@ -52,7 +57,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// log.Debugf("rawJSON: %s", string(rawJSON))
|
|
||||||
if !bytes.HasPrefix(rawJSON, dataTag) {
|
if !bytes.HasPrefix(rawJSON, dataTag) {
|
||||||
return [][]byte{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
@@ -60,9 +64,18 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
|
|
||||||
output := make([]byte, 0, 512)
|
output := make([]byte, 0, 512)
|
||||||
rootResult := gjson.ParseBytes(rawJSON)
|
rootResult := gjson.ParseBytes(rawJSON)
|
||||||
|
params := (*param).(*ConvertCodexResponseToClaudeParams)
|
||||||
|
if params.ThinkingBlockOpen && params.ThinkingStopPending {
|
||||||
|
switch rootResult.Get("type").String() {
|
||||||
|
case "response.content_part.added", "response.completed":
|
||||||
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
typeResult := rootResult.Get("type")
|
typeResult := rootResult.Get("type")
|
||||||
typeStr := typeResult.String()
|
typeStr := typeResult.String()
|
||||||
var template []byte
|
var template []byte
|
||||||
|
|
||||||
if typeStr == "response.created" {
|
if typeStr == "response.created" {
|
||||||
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
|
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
|
||||||
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
|
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
|
||||||
@@ -70,43 +83,49 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
|
||||||
} else if typeStr == "response.reasoning_summary_part.added" {
|
} else if typeStr == "response.reasoning_summary_part.added" {
|
||||||
|
if params.ThinkingBlockOpen && params.ThinkingStopPending {
|
||||||
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
|
}
|
||||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
|
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
|
params.ThinkingBlockOpen = true
|
||||||
|
params.ThinkingStopPending = false
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||||
} else if typeStr == "response.reasoning_summary_text.delta" {
|
} else if typeStr == "response.reasoning_summary_text.delta" {
|
||||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
|
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
|
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||||
} else if typeStr == "response.reasoning_summary_part.done" {
|
} else if typeStr == "response.reasoning_summary_part.done" {
|
||||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
params.ThinkingStopPending = true
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
if params.ThinkingSignature != "" {
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
|
}
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
|
||||||
|
|
||||||
} else if typeStr == "response.content_part.added" {
|
} else if typeStr == "response.content_part.added" {
|
||||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
|
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
|
params.TextBlockOpen = true
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||||
} else if typeStr == "response.output_text.delta" {
|
} else if typeStr == "response.output_text.delta" {
|
||||||
|
params.HasTextDelta = true
|
||||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
|
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
|
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||||
} else if typeStr == "response.content_part.done" {
|
} else if typeStr == "response.content_part.done" {
|
||||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
params.TextBlockOpen = false
|
||||||
|
params.BlockIndex++
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||||
} else if typeStr == "response.completed" {
|
} else if typeStr == "response.completed" {
|
||||||
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
|
||||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
p := params.HasToolCall
|
||||||
stopReason := rootResult.Get("response.stop_reason").String()
|
stopReason := rootResult.Get("response.stop_reason").String()
|
||||||
if p {
|
if p {
|
||||||
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
|
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
|
||||||
@@ -128,13 +147,13 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
itemResult := rootResult.Get("item")
|
itemResult := rootResult.Get("item")
|
||||||
itemType := itemResult.Get("type").String()
|
itemType := itemResult.Get("type").String()
|
||||||
if itemType == "function_call" {
|
if itemType == "function_call" {
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
params.HasToolCall = true
|
||||||
|
params.HasReceivedArgumentsDelta = false
|
||||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||||
{
|
{
|
||||||
// Restore original tool name if shortened
|
|
||||||
name := itemResult.Get("name").String()
|
name := itemResult.Get("name").String()
|
||||||
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||||
if orig, ok := rev[name]; ok {
|
if orig, ok := rev[name]; ok {
|
||||||
@@ -146,37 +165,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||||
|
|
||||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||||
|
} else if itemType == "reasoning" {
|
||||||
|
params.ThinkingSignature = itemResult.Get("encrypted_content").String()
|
||||||
|
if params.ThinkingStopPending {
|
||||||
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if typeStr == "response.output_item.done" {
|
} else if typeStr == "response.output_item.done" {
|
||||||
itemResult := rootResult.Get("item")
|
itemResult := rootResult.Get("item")
|
||||||
itemType := itemResult.Get("type").String()
|
itemType := itemResult.Get("type").String()
|
||||||
if itemType == "function_call" {
|
if itemType == "message" {
|
||||||
|
if params.HasTextDelta {
|
||||||
|
return [][]byte{output}
|
||||||
|
}
|
||||||
|
contentResult := itemResult.Get("content")
|
||||||
|
if !contentResult.Exists() || !contentResult.IsArray() {
|
||||||
|
return [][]byte{output}
|
||||||
|
}
|
||||||
|
var textBuilder strings.Builder
|
||||||
|
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("type").String() != "output_text" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if txt := part.Get("text").String(); txt != "" {
|
||||||
|
textBuilder.WriteString(txt)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
text := textBuilder.String()
|
||||||
|
if text == "" {
|
||||||
|
return [][]byte{output}
|
||||||
|
}
|
||||||
|
|
||||||
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
|
if !params.TextBlockOpen {
|
||||||
|
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
|
||||||
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
|
params.TextBlockOpen = true
|
||||||
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
|
||||||
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
|
template, _ = sjson.SetBytes(template, "delta.text", text)
|
||||||
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||||
|
|
||||||
template = []byte(`{"type":"content_block_stop","index":0}`)
|
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
params.TextBlockOpen = false
|
||||||
|
params.BlockIndex++
|
||||||
|
params.HasTextDelta = true
|
||||||
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||||
|
} else if itemType == "function_call" {
|
||||||
|
template = []byte(`{"type":"content_block_stop","index":0}`)
|
||||||
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
|
params.BlockIndex++
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
|
||||||
|
} else if itemType == "reasoning" {
|
||||||
|
if signature := itemResult.Get("encrypted_content").String(); signature != "" {
|
||||||
|
params.ThinkingSignature = signature
|
||||||
|
}
|
||||||
|
output = append(output, finalizeCodexThinkingBlock(params)...)
|
||||||
|
params.ThinkingSignature = ""
|
||||||
}
|
}
|
||||||
} else if typeStr == "response.function_call_arguments.delta" {
|
} else if typeStr == "response.function_call_arguments.delta" {
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
params.HasReceivedArgumentsDelta = true
|
||||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
|
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||||
} else if typeStr == "response.function_call_arguments.done" {
|
} else if typeStr == "response.function_call_arguments.done" {
|
||||||
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
if !params.HasReceivedArgumentsDelta {
|
||||||
// in a single "done" event without preceding "delta" events.
|
|
||||||
// Emit the full arguments as a single input_json_delta so the
|
|
||||||
// downstream Claude client receives the complete tool input.
|
|
||||||
// When delta events were already received, skip to avoid duplicating arguments.
|
|
||||||
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
|
||||||
if args := rootResult.Get("arguments").String(); args != "" {
|
if args := rootResult.Get("arguments").String(); args != "" {
|
||||||
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
|
||||||
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||||
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
|
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
|
||||||
|
|
||||||
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
|
||||||
@@ -191,15 +258,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible
|
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible
|
||||||
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
|
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
|
||||||
// the information into a single response that matches the Claude Code API format.
|
// the information into a single response that matches the Claude Code API format.
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - ctx: The context for the request, used for cancellation and timeout handling
|
|
||||||
// - modelName: The name of the model being used for the response (unused in current implementation)
|
|
||||||
// - rawJSON: The raw JSON response from the Codex API
|
|
||||||
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
|
|
||||||
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
|
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
|
||||||
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||||
|
|
||||||
@@ -230,6 +288,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
switch item.Get("type").String() {
|
switch item.Get("type").String() {
|
||||||
case "reasoning":
|
case "reasoning":
|
||||||
thinkingBuilder := strings.Builder{}
|
thinkingBuilder := strings.Builder{}
|
||||||
|
signature := item.Get("encrypted_content").String()
|
||||||
if summary := item.Get("summary"); summary.Exists() {
|
if summary := item.Get("summary"); summary.Exists() {
|
||||||
if summary.IsArray() {
|
if summary.IsArray() {
|
||||||
summary.ForEach(func(_, part gjson.Result) bool {
|
summary.ForEach(func(_, part gjson.Result) bool {
|
||||||
@@ -260,9 +319,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if thinkingBuilder.Len() > 0 {
|
if thinkingBuilder.Len() > 0 || signature != "" {
|
||||||
block := []byte(`{"type":"thinking","thinking":""}`)
|
block := []byte(`{"type":"thinking","thinking":""}`)
|
||||||
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
|
||||||
|
if signature != "" {
|
||||||
|
block, _ = sjson.SetBytes(block, "signature", signature)
|
||||||
|
}
|
||||||
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
out, _ = sjson.SetRawBytes(out, "content.-1", block)
|
||||||
}
|
}
|
||||||
case "message":
|
case "message":
|
||||||
@@ -371,6 +433,30 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
|
|||||||
return rev
|
return rev
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
|
func ClaudeTokenCount(_ context.Context, count int64) []byte {
|
||||||
return translatorcommon.ClaudeInputTokensJSON(count)
|
return translatorcommon.ClaudeInputTokensJSON(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte {
|
||||||
|
if !params.ThinkingBlockOpen {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
output := make([]byte, 0, 256)
|
||||||
|
if params.ThinkingSignature != "" {
|
||||||
|
signatureDelta := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":""}}`)
|
||||||
|
signatureDelta, _ = sjson.SetBytes(signatureDelta, "index", params.BlockIndex)
|
||||||
|
signatureDelta, _ = sjson.SetBytes(signatureDelta, "delta.signature", params.ThinkingSignature)
|
||||||
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", signatureDelta, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlockStop := []byte(`{"type":"content_block_stop","index":0}`)
|
||||||
|
contentBlockStop, _ = sjson.SetBytes(contentBlockStop, "index", params.BlockIndex)
|
||||||
|
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", contentBlockStop, 2)
|
||||||
|
|
||||||
|
params.BlockIndex++
|
||||||
|
params.ThinkingBlockOpen = false
|
||||||
|
params.ThinkingStopPending = false
|
||||||
|
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|||||||
319
internal/translator/codex/claude/codex_claude_response_test.go
Normal file
319
internal/translator/codex/claude/codex_claude_response_test.go
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"messages":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"model\":\"gpt-5\"}}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_123\"}}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
startFound := false
|
||||||
|
signatureDeltaFound := false
|
||||||
|
stopFound := false
|
||||||
|
|
||||||
|
for _, out := range outputs {
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||||
|
switch data.Get("type").String() {
|
||||||
|
case "content_block_start":
|
||||||
|
if data.Get("content_block.type").String() == "thinking" {
|
||||||
|
startFound = true
|
||||||
|
if data.Get("content_block.signature").Exists() {
|
||||||
|
t.Fatalf("thinking start block should NOT have signature field when signature is unknown: %s", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_delta":
|
||||||
|
if data.Get("delta.type").String() == "signature_delta" {
|
||||||
|
signatureDeltaFound = true
|
||||||
|
if got := data.Get("delta.signature").String(); got != "enc_sig_123" {
|
||||||
|
t.Fatalf("unexpected signature delta: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_stop":
|
||||||
|
stopFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !startFound {
|
||||||
|
t.Fatal("expected thinking content_block_start event")
|
||||||
|
}
|
||||||
|
if !signatureDeltaFound {
|
||||||
|
t.Fatal("expected signature_delta event for thinking block")
|
||||||
|
}
|
||||||
|
if !stopFound {
|
||||||
|
t.Fatal("expected content_block_stop event for thinking block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"messages":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
thinkingStartFound := false
|
||||||
|
thinkingStopFound := false
|
||||||
|
signatureDeltaFound := false
|
||||||
|
|
||||||
|
for _, out := range outputs {
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||||
|
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
|
||||||
|
thinkingStartFound = true
|
||||||
|
if data.Get("content_block.signature").Exists() {
|
||||||
|
t.Fatalf("thinking start block should NOT have signature field without encrypted_content: %s", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0 {
|
||||||
|
thinkingStopFound = true
|
||||||
|
}
|
||||||
|
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
|
||||||
|
signatureDeltaFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !thinkingStartFound {
|
||||||
|
t.Fatal("expected thinking content_block_start event")
|
||||||
|
}
|
||||||
|
if !thinkingStopFound {
|
||||||
|
t.Fatal("expected thinking content_block_stop event")
|
||||||
|
}
|
||||||
|
if signatureDeltaFound {
|
||||||
|
t.Fatal("did not expect signature_delta without encrypted_content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"messages":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
startCount := 0
|
||||||
|
stopCount := 0
|
||||||
|
for _, out := range outputs {
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||||
|
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
|
||||||
|
startCount++
|
||||||
|
}
|
||||||
|
if data.Get("type").String() == "content_block_stop" {
|
||||||
|
stopCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startCount != 2 {
|
||||||
|
t.Fatalf("expected 2 thinking block starts, got %d", startCount)
|
||||||
|
}
|
||||||
|
if stopCount != 1 {
|
||||||
|
t.Fatalf("expected pending thinking block to be finalized before second start, got %d stops", stopCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"messages":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_multipart\"}}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Second part\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureDeltaCount := 0
|
||||||
|
for _, out := range outputs {
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||||
|
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
|
||||||
|
signatureDeltaCount++
|
||||||
|
if got := data.Get("delta.signature").String(); got != "enc_sig_multipart" {
|
||||||
|
t.Fatalf("unexpected signature delta: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if signatureDeltaCount != 2 {
|
||||||
|
t.Fatalf("expected signature_delta for both multipart thinking blocks, got %d", signatureDeltaCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"messages":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_early\"}}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureDeltaCount := 0
|
||||||
|
for _, out := range outputs {
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||||
|
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
|
||||||
|
signatureDeltaCount++
|
||||||
|
if got := data.Get("delta.signature").String(); got != "enc_sig_early" {
|
||||||
|
t.Fatalf("unexpected signature delta: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if signatureDeltaCount != 1 {
|
||||||
|
t.Fatalf("expected signature_delta from early-captured signature, got %d", signatureDeltaCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"messages":[]}`)
|
||||||
|
response := []byte(`{
|
||||||
|
"type":"response.completed",
|
||||||
|
"response":{
|
||||||
|
"id":"resp_123",
|
||||||
|
"model":"gpt-5",
|
||||||
|
"usage":{"input_tokens":10,"output_tokens":20},
|
||||||
|
"output":[
|
||||||
|
{
|
||||||
|
"type":"reasoning",
|
||||||
|
"encrypted_content":"enc_sig_nonstream",
|
||||||
|
"summary":[{"type":"summary_text","text":"internal reasoning"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type":"message",
|
||||||
|
"content":[{"type":"output_text","text":"final answer"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil)
|
||||||
|
parsed := gjson.ParseBytes(out)
|
||||||
|
|
||||||
|
thinking := parsed.Get("content.0")
|
||||||
|
if thinking.Get("type").String() != "thinking" {
|
||||||
|
t.Fatalf("expected first content block to be thinking, got %s", thinking.Raw)
|
||||||
|
}
|
||||||
|
if got := thinking.Get("signature").String(); got != "enc_sig_nonstream" {
|
||||||
|
t.Fatalf("expected signature to be preserved, got %q", got)
|
||||||
|
}
|
||||||
|
if got := thinking.Get("thinking").String(); got != "internal reasoning" {
|
||||||
|
t.Fatalf("unexpected thinking text: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToClaude_StreamEmptyOutputUsesOutputItemDoneMessageFallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"tools":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5\"}}"),
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}"),
|
||||||
|
[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
foundText := false
|
||||||
|
for _, out := range outputs {
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||||
|
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "text_delta" && data.Get("delta.text").String() == "ok" {
|
||||||
|
foundText = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundText {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundText {
|
||||||
|
t.Fatalf("expected fallback content from response.output_item.done message; outputs=%q", outputs)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -20,10 +20,11 @@ var (
|
|||||||
|
|
||||||
// ConvertCodexResponseToGeminiParams holds parameters for response conversion.
|
// ConvertCodexResponseToGeminiParams holds parameters for response conversion.
|
||||||
type ConvertCodexResponseToGeminiParams struct {
|
type ConvertCodexResponseToGeminiParams struct {
|
||||||
Model string
|
Model string
|
||||||
CreatedAt int64
|
CreatedAt int64
|
||||||
ResponseID string
|
ResponseID string
|
||||||
LastStorageOutput []byte
|
LastStorageOutput []byte
|
||||||
|
HasOutputTextDelta bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
|
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
|
||||||
@@ -42,10 +43,11 @@ type ConvertCodexResponseToGeminiParams struct {
|
|||||||
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &ConvertCodexResponseToGeminiParams{
|
*param = &ConvertCodexResponseToGeminiParams{
|
||||||
Model: modelName,
|
Model: modelName,
|
||||||
CreatedAt: 0,
|
CreatedAt: 0,
|
||||||
ResponseID: "",
|
ResponseID: "",
|
||||||
LastStorageOutput: nil,
|
LastStorageOutput: nil,
|
||||||
|
HasOutputTextDelta: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,18 +60,18 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
|||||||
typeResult := rootResult.Get("type")
|
typeResult := rootResult.Get("type")
|
||||||
typeStr := typeResult.String()
|
typeStr := typeResult.String()
|
||||||
|
|
||||||
|
params := (*param).(*ConvertCodexResponseToGeminiParams)
|
||||||
|
|
||||||
// Base Gemini response template
|
// Base Gemini response template
|
||||||
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
|
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
|
||||||
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" {
|
{
|
||||||
template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...)
|
template, _ = sjson.SetBytes(template, "modelVersion", params.Model)
|
||||||
} else {
|
|
||||||
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
|
|
||||||
createdAtResult := rootResult.Get("response.created_at")
|
createdAtResult := rootResult.Get("response.created_at")
|
||||||
if createdAtResult.Exists() {
|
if createdAtResult.Exists() {
|
||||||
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int()
|
params.CreatedAt = createdAtResult.Int()
|
||||||
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
|
template, _ = sjson.SetBytes(template, "createTime", time.Unix(params.CreatedAt, 0).Format(time.RFC3339Nano))
|
||||||
}
|
}
|
||||||
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
|
template, _ = sjson.SetBytes(template, "responseId", params.ResponseID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle function call completion
|
// Handle function call completion
|
||||||
@@ -101,7 +103,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
|||||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
|
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
|
||||||
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
|
||||||
|
|
||||||
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
|
params.LastStorageOutput = append([]byte(nil), template...)
|
||||||
|
|
||||||
// Use this return to storage message
|
// Use this return to storage message
|
||||||
return [][]byte{}
|
return [][]byte{}
|
||||||
@@ -111,15 +113,45 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
|||||||
if typeStr == "response.created" { // Handle response creation - set model and response ID
|
if typeStr == "response.created" { // Handle response creation - set model and response ID
|
||||||
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
|
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
|
||||||
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
|
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
|
||||||
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String()
|
params.ResponseID = rootResult.Get("response.id").String()
|
||||||
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
|
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
|
||||||
part := []byte(`{"thought":true,"text":""}`)
|
part := []byte(`{"thought":true,"text":""}`)
|
||||||
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
||||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||||
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
|
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
|
||||||
|
params.HasOutputTextDelta = true
|
||||||
part := []byte(`{"text":""}`)
|
part := []byte(`{"text":""}`)
|
||||||
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
|
||||||
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||||
|
} else if typeStr == "response.output_item.done" { // Fallback: emit final message text when no delta chunks were received
|
||||||
|
itemResult := rootResult.Get("item")
|
||||||
|
if itemResult.Get("type").String() != "message" || params.HasOutputTextDelta {
|
||||||
|
return [][]byte{}
|
||||||
|
}
|
||||||
|
contentResult := itemResult.Get("content")
|
||||||
|
if !contentResult.Exists() || !contentResult.IsArray() {
|
||||||
|
return [][]byte{}
|
||||||
|
}
|
||||||
|
wroteText := false
|
||||||
|
contentResult.ForEach(func(_, partResult gjson.Result) bool {
|
||||||
|
if partResult.Get("type").String() != "output_text" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
text := partResult.Get("text").String()
|
||||||
|
if text == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
part := []byte(`{"text":""}`)
|
||||||
|
part, _ = sjson.SetBytes(part, "text", text)
|
||||||
|
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
|
||||||
|
wroteText = true
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if wroteText {
|
||||||
|
params.HasOutputTextDelta = true
|
||||||
|
return [][]byte{template}
|
||||||
|
}
|
||||||
|
return [][]byte{}
|
||||||
} else if typeStr == "response.completed" { // Handle response completion with usage metadata
|
} else if typeStr == "response.completed" { // Handle response completion with usage metadata
|
||||||
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
|
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
|
||||||
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
|
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
|
||||||
@@ -129,11 +161,10 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
|
|||||||
return [][]byte{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 {
|
if len(params.LastStorageOutput) > 0 {
|
||||||
return [][]byte{
|
stored := append([]byte(nil), params.LastStorageOutput...)
|
||||||
append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...),
|
params.LastStorageOutput = nil
|
||||||
template,
|
return [][]byte{stored, template}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return [][]byte{template}
|
return [][]byte{template}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertCodexResponseToGemini_StreamEmptyOutputUsesOutputItemDoneMessageFallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
originalRequest := []byte(`{"tools":[]}`)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
chunks := [][]byte{
|
||||||
|
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}"),
|
||||||
|
[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs [][]byte
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
outputs = append(outputs, ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, chunk, ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, out := range outputs {
|
||||||
|
if gjson.GetBytes(out, "candidates.0.content.parts.0.text").String() == "ok" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected fallback content from response.output_item.done message; outputs=%q", outputs)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -284,12 +284,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the output array for content and function calls
|
// Process the output array for content and function calls
|
||||||
|
var toolCalls [][]byte
|
||||||
outputResult := responseResult.Get("output")
|
outputResult := responseResult.Get("output")
|
||||||
if outputResult.IsArray() {
|
if outputResult.IsArray() {
|
||||||
outputArray := outputResult.Array()
|
outputArray := outputResult.Array()
|
||||||
var contentText string
|
var contentText string
|
||||||
var reasoningText string
|
var reasoningText string
|
||||||
var toolCalls [][]byte
|
|
||||||
|
|
||||||
for _, outputItem := range outputArray {
|
for _, outputItem := range outputArray {
|
||||||
outputType := outputItem.Get("type").String()
|
outputType := outputItem.Get("type").String()
|
||||||
@@ -367,8 +367,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
|||||||
if statusResult := responseResult.Get("status"); statusResult.Exists() {
|
if statusResult := responseResult.Get("status"); statusResult.Exists() {
|
||||||
status := statusResult.String()
|
status := statusResult.String()
|
||||||
if status == "completed" {
|
if status == "completed" {
|
||||||
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
|
finishReason := "stop"
|
||||||
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
|
if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
|
||||||
|
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
package claude
|
package claude
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
@@ -31,8 +31,6 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
|||||||
// - []byte: The transformed request in Gemini CLI format.
|
// - []byte: The transformed request in Gemini CLI format.
|
||||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
rawJSON := inputRawJSON
|
rawJSON := inputRawJSON
|
||||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
|
||||||
|
|
||||||
// Build output Gemini CLI request JSON
|
// Build output Gemini CLI request JSON
|
||||||
out := []byte(`{"contents":[]}`)
|
out := []byte(`{"contents":[]}`)
|
||||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||||
@@ -146,13 +144,37 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// strip trailing model turn with unanswered function calls —
|
||||||
|
// Gemini returns empty responses when the last turn is a model
|
||||||
|
// functionCall with no corresponding user functionResponse.
|
||||||
|
contents := gjson.GetBytes(out, "contents")
|
||||||
|
if contents.Exists() && contents.IsArray() {
|
||||||
|
arr := contents.Array()
|
||||||
|
if len(arr) > 0 {
|
||||||
|
last := arr[len(arr)-1]
|
||||||
|
if last.Get("role").String() == "model" {
|
||||||
|
hasFC := false
|
||||||
|
last.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("functionCall").Exists() {
|
||||||
|
hasFC = true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if hasFC {
|
||||||
|
out, _ = sjson.DeleteBytes(out, fmt.Sprintf("contents.%d", len(arr)-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// tools
|
// tools
|
||||||
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
|
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
|
||||||
hasTools := false
|
hasTools := false
|
||||||
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
||||||
inputSchemaResult := toolResult.Get("input_schema")
|
inputSchemaResult := toolResult.Get("input_schema")
|
||||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||||
inputSchema := inputSchemaResult.Raw
|
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||||
tool := []byte(toolResult.Raw)
|
tool := []byte(toolResult.Raw)
|
||||||
var err error
|
var err error
|
||||||
tool, err = sjson.DeleteBytes(tool, "input_schema")
|
tool, err = sjson.DeleteBytes(tool, "input_schema")
|
||||||
@@ -168,6 +190,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||||
|
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
|
||||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||||
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
||||||
if !hasTools {
|
if !hasTools {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,29 +17,35 @@ import (
|
|||||||
type oaiToResponsesStateReasoning struct {
|
type oaiToResponsesStateReasoning struct {
|
||||||
ReasoningID string
|
ReasoningID string
|
||||||
ReasoningData string
|
ReasoningData string
|
||||||
|
OutputIndex int
|
||||||
}
|
}
|
||||||
type oaiToResponsesState struct {
|
type oaiToResponsesState struct {
|
||||||
Seq int
|
Seq int
|
||||||
ResponseID string
|
ResponseID string
|
||||||
Created int64
|
Created int64
|
||||||
Started bool
|
Started bool
|
||||||
ReasoningID string
|
CompletionPending bool
|
||||||
ReasoningIndex int
|
CompletedEmitted bool
|
||||||
|
ReasoningID string
|
||||||
|
ReasoningIndex int
|
||||||
// aggregation buffers for response.output
|
// aggregation buffers for response.output
|
||||||
// Per-output message text buffers by index
|
// Per-output message text buffers by index
|
||||||
MsgTextBuf map[int]*strings.Builder
|
MsgTextBuf map[int]*strings.Builder
|
||||||
ReasoningBuf strings.Builder
|
ReasoningBuf strings.Builder
|
||||||
Reasonings []oaiToResponsesStateReasoning
|
Reasonings []oaiToResponsesStateReasoning
|
||||||
FuncArgsBuf map[int]*strings.Builder // index -> args
|
FuncArgsBuf map[string]*strings.Builder
|
||||||
FuncNames map[int]string // index -> name
|
FuncNames map[string]string
|
||||||
FuncCallIDs map[int]string // index -> call_id
|
FuncCallIDs map[string]string
|
||||||
|
FuncOutputIx map[string]int
|
||||||
|
MsgOutputIx map[int]int
|
||||||
|
NextOutputIx int
|
||||||
// message item state per output index
|
// message item state per output index
|
||||||
MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
|
MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
|
||||||
MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
|
MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
|
||||||
MsgItemDone map[int]bool // whether message done events were emitted
|
MsgItemDone map[int]bool // whether message done events were emitted
|
||||||
// function item done state
|
// function item done state
|
||||||
FuncArgsDone map[int]bool
|
FuncArgsDone map[string]bool
|
||||||
FuncItemDone map[int]bool
|
FuncItemDone map[string]bool
|
||||||
// usage aggregation
|
// usage aggregation
|
||||||
PromptTokens int64
|
PromptTokens int64
|
||||||
CachedTokens int64
|
CachedTokens int64
|
||||||
@@ -55,20 +62,157 @@ func emitRespEvent(event string, payload []byte) []byte {
|
|||||||
return translatorcommon.SSEEventData(event, payload)
|
return translatorcommon.SSEEventData(event, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte {
|
||||||
|
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
||||||
|
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
|
||||||
|
// Inject original request fields into response as per docs/response.completed.json
|
||||||
|
if requestRawJSON != nil {
|
||||||
|
req := gjson.ParseBytes(requestRawJSON)
|
||||||
|
if v := req.Get("instructions"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("max_output_tokens"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
|
||||||
|
}
|
||||||
|
if v := req.Get("max_tool_calls"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
|
||||||
|
}
|
||||||
|
if v := req.Get("model"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
|
||||||
|
}
|
||||||
|
if v := req.Get("previous_response_id"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("prompt_cache_key"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("reasoning"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
|
||||||
|
}
|
||||||
|
if v := req.Get("safety_identifier"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("service_tier"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("store"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
|
||||||
|
}
|
||||||
|
if v := req.Get("temperature"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
|
||||||
|
}
|
||||||
|
if v := req.Get("text"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
|
||||||
|
}
|
||||||
|
if v := req.Get("tool_choice"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
|
||||||
|
}
|
||||||
|
if v := req.Get("tools"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
|
||||||
|
}
|
||||||
|
if v := req.Get("top_logprobs"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
|
||||||
|
}
|
||||||
|
if v := req.Get("top_p"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
|
||||||
|
}
|
||||||
|
if v := req.Get("truncation"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
|
||||||
|
}
|
||||||
|
if v := req.Get("user"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
|
||||||
|
}
|
||||||
|
if v := req.Get("metadata"); v.Exists() {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outputsWrapper := []byte(`{"arr":[]}`)
|
||||||
|
type completedOutputItem struct {
|
||||||
|
index int
|
||||||
|
raw []byte
|
||||||
|
}
|
||||||
|
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
|
||||||
|
if len(st.Reasonings) > 0 {
|
||||||
|
for _, r := range st.Reasonings {
|
||||||
|
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
||||||
|
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
|
||||||
|
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
|
||||||
|
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(st.MsgItemAdded) > 0 {
|
||||||
|
for i := range st.MsgItemAdded {
|
||||||
|
txt := ""
|
||||||
|
if b := st.MsgTextBuf[i]; b != nil {
|
||||||
|
txt = b.String()
|
||||||
|
}
|
||||||
|
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
||||||
|
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||||
|
item, _ = sjson.SetBytes(item, "content.0.text", txt)
|
||||||
|
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(st.FuncArgsBuf) > 0 {
|
||||||
|
for key := range st.FuncArgsBuf {
|
||||||
|
args := ""
|
||||||
|
if b := st.FuncArgsBuf[key]; b != nil {
|
||||||
|
args = b.String()
|
||||||
|
}
|
||||||
|
callID := st.FuncCallIDs[key]
|
||||||
|
name := st.FuncNames[key]
|
||||||
|
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
||||||
|
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||||
|
item, _ = sjson.SetBytes(item, "arguments", args)
|
||||||
|
item, _ = sjson.SetBytes(item, "call_id", callID)
|
||||||
|
item, _ = sjson.SetBytes(item, "name", name)
|
||||||
|
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
|
||||||
|
for _, item := range outputItems {
|
||||||
|
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
||||||
|
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
||||||
|
}
|
||||||
|
if st.UsageSeen {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
|
||||||
|
if st.ReasoningTokens > 0 {
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
|
||||||
|
}
|
||||||
|
total := st.TotalTokens
|
||||||
|
if total == 0 {
|
||||||
|
total = st.PromptTokens + st.CompletionTokens
|
||||||
|
}
|
||||||
|
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
|
||||||
|
}
|
||||||
|
return emitRespEvent("response.completed", completed)
|
||||||
|
}
|
||||||
|
|
||||||
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||||
// to OpenAI Responses SSE events (response.*).
|
// to OpenAI Responses SSE events (response.*).
|
||||||
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &oaiToResponsesState{
|
*param = &oaiToResponsesState{
|
||||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
FuncArgsBuf: make(map[string]*strings.Builder),
|
||||||
FuncNames: make(map[int]string),
|
FuncNames: make(map[string]string),
|
||||||
FuncCallIDs: make(map[int]string),
|
FuncCallIDs: make(map[string]string),
|
||||||
|
FuncOutputIx: make(map[string]int),
|
||||||
|
MsgOutputIx: make(map[int]int),
|
||||||
MsgTextBuf: make(map[int]*strings.Builder),
|
MsgTextBuf: make(map[int]*strings.Builder),
|
||||||
MsgItemAdded: make(map[int]bool),
|
MsgItemAdded: make(map[int]bool),
|
||||||
MsgContentAdded: make(map[int]bool),
|
MsgContentAdded: make(map[int]bool),
|
||||||
MsgItemDone: make(map[int]bool),
|
MsgItemDone: make(map[int]bool),
|
||||||
FuncArgsDone: make(map[int]bool),
|
FuncArgsDone: make(map[string]bool),
|
||||||
FuncItemDone: make(map[int]bool),
|
FuncItemDone: make(map[string]bool),
|
||||||
Reasonings: make([]oaiToResponsesStateReasoning, 0),
|
Reasonings: make([]oaiToResponsesStateReasoning, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -83,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
return [][]byte{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
if bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||||
|
if st.CompletionPending && !st.CompletedEmitted {
|
||||||
|
st.CompletedEmitted = true
|
||||||
|
return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })}
|
||||||
|
}
|
||||||
return [][]byte{}
|
return [][]byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,6 +273,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||||
|
allocOutputIndex := func() int {
|
||||||
|
ix := st.NextOutputIx
|
||||||
|
st.NextOutputIx++
|
||||||
|
return ix
|
||||||
|
}
|
||||||
|
toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
|
||||||
var out [][]byte
|
var out [][]byte
|
||||||
|
|
||||||
if !st.Started {
|
if !st.Started {
|
||||||
@@ -135,20 +289,25 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
st.ReasoningBuf.Reset()
|
st.ReasoningBuf.Reset()
|
||||||
st.ReasoningID = ""
|
st.ReasoningID = ""
|
||||||
st.ReasoningIndex = 0
|
st.ReasoningIndex = 0
|
||||||
st.FuncArgsBuf = make(map[int]*strings.Builder)
|
st.FuncArgsBuf = make(map[string]*strings.Builder)
|
||||||
st.FuncNames = make(map[int]string)
|
st.FuncNames = make(map[string]string)
|
||||||
st.FuncCallIDs = make(map[int]string)
|
st.FuncCallIDs = make(map[string]string)
|
||||||
|
st.FuncOutputIx = make(map[string]int)
|
||||||
|
st.MsgOutputIx = make(map[int]int)
|
||||||
|
st.NextOutputIx = 0
|
||||||
st.MsgItemAdded = make(map[int]bool)
|
st.MsgItemAdded = make(map[int]bool)
|
||||||
st.MsgContentAdded = make(map[int]bool)
|
st.MsgContentAdded = make(map[int]bool)
|
||||||
st.MsgItemDone = make(map[int]bool)
|
st.MsgItemDone = make(map[int]bool)
|
||||||
st.FuncArgsDone = make(map[int]bool)
|
st.FuncArgsDone = make(map[string]bool)
|
||||||
st.FuncItemDone = make(map[int]bool)
|
st.FuncItemDone = make(map[string]bool)
|
||||||
st.PromptTokens = 0
|
st.PromptTokens = 0
|
||||||
st.CachedTokens = 0
|
st.CachedTokens = 0
|
||||||
st.CompletionTokens = 0
|
st.CompletionTokens = 0
|
||||||
st.TotalTokens = 0
|
st.TotalTokens = 0
|
||||||
st.ReasoningTokens = 0
|
st.ReasoningTokens = 0
|
||||||
st.UsageSeen = false
|
st.UsageSeen = false
|
||||||
|
st.CompletionPending = false
|
||||||
|
st.CompletedEmitted = false
|
||||||
// response.created
|
// response.created
|
||||||
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||||
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
|
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
|
||||||
@@ -185,7 +344,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
|
outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
|
||||||
out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
|
out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
|
||||||
|
|
||||||
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
|
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
|
||||||
st.ReasoningID = ""
|
st.ReasoningID = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,10 +360,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
stopReasoning(st.ReasoningBuf.String())
|
stopReasoning(st.ReasoningBuf.String())
|
||||||
st.ReasoningBuf.Reset()
|
st.ReasoningBuf.Reset()
|
||||||
}
|
}
|
||||||
|
if _, exists := st.MsgOutputIx[idx]; !exists {
|
||||||
|
st.MsgOutputIx[idx] = allocOutputIndex()
|
||||||
|
}
|
||||||
|
msgOutputIndex := st.MsgOutputIx[idx]
|
||||||
if !st.MsgItemAdded[idx] {
|
if !st.MsgItemAdded[idx] {
|
||||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
||||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
|
||||||
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||||
out = append(out, emitRespEvent("response.output_item.added", item))
|
out = append(out, emitRespEvent("response.output_item.added", item))
|
||||||
st.MsgItemAdded[idx] = true
|
st.MsgItemAdded[idx] = true
|
||||||
@@ -213,7 +376,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||||
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
|
||||||
part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||||
part, _ = sjson.SetBytes(part, "output_index", idx)
|
part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
|
||||||
part, _ = sjson.SetBytes(part, "content_index", 0)
|
part, _ = sjson.SetBytes(part, "content_index", 0)
|
||||||
out = append(out, emitRespEvent("response.content_part.added", part))
|
out = append(out, emitRespEvent("response.content_part.added", part))
|
||||||
st.MsgContentAdded[idx] = true
|
st.MsgContentAdded[idx] = true
|
||||||
@@ -222,7 +385,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
|
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
|
||||||
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
|
||||||
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||||
msg, _ = sjson.SetBytes(msg, "output_index", idx)
|
msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
|
||||||
msg, _ = sjson.SetBytes(msg, "content_index", 0)
|
msg, _ = sjson.SetBytes(msg, "content_index", 0)
|
||||||
msg, _ = sjson.SetBytes(msg, "delta", c.String())
|
msg, _ = sjson.SetBytes(msg, "delta", c.String())
|
||||||
out = append(out, emitRespEvent("response.output_text.delta", msg))
|
out = append(out, emitRespEvent("response.output_text.delta", msg))
|
||||||
@@ -238,10 +401,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
// On first appearance, add reasoning item and part
|
// On first appearance, add reasoning item and part
|
||||||
if st.ReasoningID == "" {
|
if st.ReasoningID == "" {
|
||||||
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
|
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
|
||||||
st.ReasoningIndex = idx
|
st.ReasoningIndex = allocOutputIndex()
|
||||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
|
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
|
||||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||||
item, _ = sjson.SetBytes(item, "output_index", idx)
|
item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
|
||||||
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
|
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
|
||||||
out = append(out, emitRespEvent("response.output_item.added", item))
|
out = append(out, emitRespEvent("response.output_item.added", item))
|
||||||
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
|
||||||
@@ -269,6 +432,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
// Before emitting any function events, if a message is open for this index,
|
// Before emitting any function events, if a message is open for this index,
|
||||||
// close its text/content to match Codex expected ordering.
|
// close its text/content to match Codex expected ordering.
|
||||||
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
|
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
|
||||||
|
msgOutputIndex := st.MsgOutputIx[idx]
|
||||||
fullText := ""
|
fullText := ""
|
||||||
if b := st.MsgTextBuf[idx]; b != nil {
|
if b := st.MsgTextBuf[idx]; b != nil {
|
||||||
fullText = b.String()
|
fullText = b.String()
|
||||||
@@ -276,7 +440,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||||
done, _ = sjson.SetBytes(done, "output_index", idx)
|
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||||
@@ -284,74 +448,78 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||||
partDone, _ = sjson.SetBytes(partDone, "output_index", idx)
|
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||||
|
|
||||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
|
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||||
st.MsgItemDone[idx] = true
|
st.MsgItemDone[idx] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only emit item.added once per tool call and preserve call_id across chunks.
|
tcs.ForEach(func(_, tc gjson.Result) bool {
|
||||||
newCallID := tcs.Get("0.id").String()
|
toolIndex := int(tc.Get("index").Int())
|
||||||
nameChunk := tcs.Get("0.function.name").String()
|
key := toolStateKey(idx, toolIndex)
|
||||||
if nameChunk != "" {
|
newCallID := tc.Get("id").String()
|
||||||
st.FuncNames[idx] = nameChunk
|
nameChunk := tc.Get("function.name").String()
|
||||||
}
|
if nameChunk != "" {
|
||||||
existingCallID := st.FuncCallIDs[idx]
|
st.FuncNames[key] = nameChunk
|
||||||
effectiveCallID := existingCallID
|
|
||||||
shouldEmitItem := false
|
|
||||||
if existingCallID == "" && newCallID != "" {
|
|
||||||
// First time seeing a valid call_id for this index
|
|
||||||
effectiveCallID = newCallID
|
|
||||||
st.FuncCallIDs[idx] = newCallID
|
|
||||||
shouldEmitItem = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldEmitItem && effectiveCallID != "" {
|
|
||||||
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
|
||||||
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
|
|
||||||
o, _ = sjson.SetBytes(o, "output_index", idx)
|
|
||||||
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
|
|
||||||
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
|
|
||||||
name := st.FuncNames[idx]
|
|
||||||
o, _ = sjson.SetBytes(o, "item.name", name)
|
|
||||||
out = append(out, emitRespEvent("response.output_item.added", o))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure args buffer exists for this index
|
|
||||||
if st.FuncArgsBuf[idx] == nil {
|
|
||||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append arguments delta if available and we have a valid call_id to reference
|
|
||||||
if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
|
|
||||||
// Prefer an already known call_id; fall back to newCallID if first time
|
|
||||||
refCallID := st.FuncCallIDs[idx]
|
|
||||||
if refCallID == "" {
|
|
||||||
refCallID = newCallID
|
|
||||||
}
|
}
|
||||||
if refCallID != "" {
|
|
||||||
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
existingCallID := st.FuncCallIDs[key]
|
||||||
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
|
effectiveCallID := existingCallID
|
||||||
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
|
shouldEmitItem := false
|
||||||
ad, _ = sjson.SetBytes(ad, "output_index", idx)
|
if existingCallID == "" && newCallID != "" {
|
||||||
ad, _ = sjson.SetBytes(ad, "delta", args.String())
|
effectiveCallID = newCallID
|
||||||
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
|
st.FuncCallIDs[key] = newCallID
|
||||||
|
st.FuncOutputIx[key] = allocOutputIndex()
|
||||||
|
shouldEmitItem = true
|
||||||
}
|
}
|
||||||
st.FuncArgsBuf[idx].WriteString(args.String())
|
|
||||||
}
|
if shouldEmitItem && effectiveCallID != "" {
|
||||||
|
outputIndex := st.FuncOutputIx[key]
|
||||||
|
o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
|
||||||
|
o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
|
||||||
|
o, _ = sjson.SetBytes(o, "output_index", outputIndex)
|
||||||
|
o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
|
||||||
|
o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
|
||||||
|
o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
|
||||||
|
out = append(out, emitRespEvent("response.output_item.added", o))
|
||||||
|
}
|
||||||
|
|
||||||
|
if st.FuncArgsBuf[key] == nil {
|
||||||
|
st.FuncArgsBuf[key] = &strings.Builder{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
|
||||||
|
refCallID := st.FuncCallIDs[key]
|
||||||
|
if refCallID == "" {
|
||||||
|
refCallID = newCallID
|
||||||
|
}
|
||||||
|
if refCallID != "" {
|
||||||
|
outputIndex := st.FuncOutputIx[key]
|
||||||
|
ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
|
||||||
|
ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
|
||||||
|
ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
|
||||||
|
ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
|
||||||
|
ad, _ = sjson.SetBytes(ad, "delta", args.String())
|
||||||
|
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
|
||||||
|
}
|
||||||
|
st.FuncArgsBuf[key].WriteString(args.String())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// finish_reason triggers finalization, including text done/content done/item done,
|
// finish_reason triggers item-level finalization. response.completed is
|
||||||
// reasoning done/part.done, function args done/item done, and completed
|
// deferred until the terminal [DONE] marker so late usage-only chunks can
|
||||||
|
// still populate response.usage.
|
||||||
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
|
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
|
||||||
// Emit message done events for all indices that started a message
|
// Emit message done events for all indices that started a message
|
||||||
if len(st.MsgItemAdded) > 0 {
|
if len(st.MsgItemAdded) > 0 {
|
||||||
@@ -360,15 +528,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
for i := range st.MsgItemAdded {
|
for i := range st.MsgItemAdded {
|
||||||
idxs = append(idxs, i)
|
idxs = append(idxs, i)
|
||||||
}
|
}
|
||||||
for i := 0; i < len(idxs); i++ {
|
sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
|
||||||
for j := i + 1; j < len(idxs); j++ {
|
|
||||||
if idxs[j] < idxs[i] {
|
|
||||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, i := range idxs {
|
for _, i := range idxs {
|
||||||
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
|
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
|
||||||
|
msgOutputIndex := st.MsgOutputIx[i]
|
||||||
fullText := ""
|
fullText := ""
|
||||||
if b := st.MsgTextBuf[i]; b != nil {
|
if b := st.MsgTextBuf[i]; b != nil {
|
||||||
fullText = b.String()
|
fullText = b.String()
|
||||||
@@ -376,7 +539,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||||
done, _ = sjson.SetBytes(done, "output_index", i)
|
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||||
@@ -384,14 +547,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||||
partDone, _ = sjson.SetBytes(partDone, "output_index", i)
|
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||||
|
|
||||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||||
@@ -407,192 +570,45 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
|||||||
|
|
||||||
// Emit function call done events for any active function calls
|
// Emit function call done events for any active function calls
|
||||||
if len(st.FuncCallIDs) > 0 {
|
if len(st.FuncCallIDs) > 0 {
|
||||||
idxs := make([]int, 0, len(st.FuncCallIDs))
|
keys := make([]string, 0, len(st.FuncCallIDs))
|
||||||
for i := range st.FuncCallIDs {
|
for key := range st.FuncCallIDs {
|
||||||
idxs = append(idxs, i)
|
keys = append(keys, key)
|
||||||
}
|
}
|
||||||
for i := 0; i < len(idxs); i++ {
|
sort.Slice(keys, func(i, j int) bool {
|
||||||
for j := i + 1; j < len(idxs); j++ {
|
left := st.FuncOutputIx[keys[i]]
|
||||||
if idxs[j] < idxs[i] {
|
right := st.FuncOutputIx[keys[j]]
|
||||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
return left < right || (left == right && keys[i] < keys[j])
|
||||||
}
|
})
|
||||||
}
|
for _, key := range keys {
|
||||||
}
|
callID := st.FuncCallIDs[key]
|
||||||
for _, i := range idxs {
|
if callID == "" || st.FuncItemDone[key] {
|
||||||
callID := st.FuncCallIDs[i]
|
|
||||||
if callID == "" || st.FuncItemDone[i] {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
outputIndex := st.FuncOutputIx[key]
|
||||||
args := "{}"
|
args := "{}"
|
||||||
if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
|
if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
|
||||||
args = b.String()
|
args = b.String()
|
||||||
}
|
}
|
||||||
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
|
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
|
||||||
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
|
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
|
||||||
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
|
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
|
||||||
fcDone, _ = sjson.SetBytes(fcDone, "output_index", i)
|
fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
|
||||||
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
|
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
|
||||||
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
|
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
|
||||||
|
|
||||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
|
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
|
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
|
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
|
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
|
||||||
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i])
|
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
|
||||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||||
st.FuncItemDone[i] = true
|
st.FuncItemDone[key] = true
|
||||||
st.FuncArgsDone[i] = true
|
st.FuncArgsDone[key] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
|
st.CompletionPending = true
|
||||||
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
|
|
||||||
// Inject original request fields into response as per docs/response.completed.json
|
|
||||||
if requestRawJSON != nil {
|
|
||||||
req := gjson.ParseBytes(requestRawJSON)
|
|
||||||
if v := req.Get("instructions"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("max_output_tokens"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
|
|
||||||
}
|
|
||||||
if v := req.Get("max_tool_calls"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
|
|
||||||
}
|
|
||||||
if v := req.Get("model"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("parallel_tool_calls"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
|
|
||||||
}
|
|
||||||
if v := req.Get("previous_response_id"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("prompt_cache_key"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("reasoning"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
|
|
||||||
}
|
|
||||||
if v := req.Get("safety_identifier"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("service_tier"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("store"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
|
|
||||||
}
|
|
||||||
if v := req.Get("temperature"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
|
|
||||||
}
|
|
||||||
if v := req.Get("text"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
|
|
||||||
}
|
|
||||||
if v := req.Get("tool_choice"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
|
|
||||||
}
|
|
||||||
if v := req.Get("tools"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
|
|
||||||
}
|
|
||||||
if v := req.Get("top_logprobs"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
|
|
||||||
}
|
|
||||||
if v := req.Get("top_p"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
|
|
||||||
}
|
|
||||||
if v := req.Get("truncation"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
|
|
||||||
}
|
|
||||||
if v := req.Get("user"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
|
|
||||||
}
|
|
||||||
if v := req.Get("metadata"); v.Exists() {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Build response.output using aggregated buffers
|
|
||||||
outputsWrapper := []byte(`{"arr":[]}`)
|
|
||||||
if len(st.Reasonings) > 0 {
|
|
||||||
for _, r := range st.Reasonings {
|
|
||||||
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
|
|
||||||
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
|
|
||||||
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
|
|
||||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Append message items in ascending index order
|
|
||||||
if len(st.MsgItemAdded) > 0 {
|
|
||||||
midxs := make([]int, 0, len(st.MsgItemAdded))
|
|
||||||
for i := range st.MsgItemAdded {
|
|
||||||
midxs = append(midxs, i)
|
|
||||||
}
|
|
||||||
for i := 0; i < len(midxs); i++ {
|
|
||||||
for j := i + 1; j < len(midxs); j++ {
|
|
||||||
if midxs[j] < midxs[i] {
|
|
||||||
midxs[i], midxs[j] = midxs[j], midxs[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, i := range midxs {
|
|
||||||
txt := ""
|
|
||||||
if b := st.MsgTextBuf[i]; b != nil {
|
|
||||||
txt = b.String()
|
|
||||||
}
|
|
||||||
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
|
|
||||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
|
||||||
item, _ = sjson.SetBytes(item, "content.0.text", txt)
|
|
||||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(st.FuncArgsBuf) > 0 {
|
|
||||||
idxs := make([]int, 0, len(st.FuncArgsBuf))
|
|
||||||
for i := range st.FuncArgsBuf {
|
|
||||||
idxs = append(idxs, i)
|
|
||||||
}
|
|
||||||
// small-N sort without extra imports
|
|
||||||
for i := 0; i < len(idxs); i++ {
|
|
||||||
for j := i + 1; j < len(idxs); j++ {
|
|
||||||
if idxs[j] < idxs[i] {
|
|
||||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, i := range idxs {
|
|
||||||
args := ""
|
|
||||||
if b := st.FuncArgsBuf[i]; b != nil {
|
|
||||||
args = b.String()
|
|
||||||
}
|
|
||||||
callID := st.FuncCallIDs[i]
|
|
||||||
name := st.FuncNames[i]
|
|
||||||
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
|
|
||||||
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
|
|
||||||
item, _ = sjson.SetBytes(item, "arguments", args)
|
|
||||||
item, _ = sjson.SetBytes(item, "call_id", callID)
|
|
||||||
item, _ = sjson.SetBytes(item, "name", name)
|
|
||||||
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
|
|
||||||
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
|
|
||||||
}
|
|
||||||
if st.UsageSeen {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
|
|
||||||
if st.ReasoningTokens > 0 {
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
|
|
||||||
}
|
|
||||||
total := st.TotalTokens
|
|
||||||
if total == 0 {
|
|
||||||
total = st.PromptTokens + st.CompletionTokens
|
|
||||||
}
|
|
||||||
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
|
|
||||||
}
|
|
||||||
out = append(out, emitRespEvent("response.completed", completed))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -0,0 +1,423 @@
|
|||||||
|
package responses
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
lines := strings.Split(string(chunk), "\n")
|
||||||
|
if len(lines) < 2 {
|
||||||
|
t.Fatalf("unexpected SSE chunk: %q", chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||||
|
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||||
|
if !gjson.Valid(dataLine) {
|
||||||
|
t.Fatalf("invalid SSE data JSON: %q", dataLine)
|
||||||
|
}
|
||||||
|
return event, gjson.Parse(dataLine)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in []string
|
||||||
|
doneInputIndex int // Index in tt.in where the terminal [DONE] chunk arrives and response.completed must be emitted.
|
||||||
|
hasUsage bool
|
||||||
|
inputTokens int64
|
||||||
|
outputTokens int64
|
||||||
|
totalTokens int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// A provider may send finish_reason first and only attach usage in a later chunk (e.g. Vertex AI),
|
||||||
|
// so response.completed must wait for [DONE] to include that usage.
|
||||||
|
name: "late usage after finish reason",
|
||||||
|
in: []string{
|
||||||
|
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_late_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
|
||||||
|
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":7,"total_tokens":18}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
},
|
||||||
|
doneInputIndex: 3,
|
||||||
|
hasUsage: true,
|
||||||
|
inputTokens: 11,
|
||||||
|
outputTokens: 7,
|
||||||
|
totalTokens: 18,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// When usage arrives on the same chunk as finish_reason, we still expect a
|
||||||
|
// single response.completed event and it should remain deferred until [DONE].
|
||||||
|
name: "usage on finish reason chunk",
|
||||||
|
in: []string{
|
||||||
|
`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_usage_same_chunk","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":13,"completion_tokens":5,"total_tokens":18}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
},
|
||||||
|
doneInputIndex: 2,
|
||||||
|
hasUsage: true,
|
||||||
|
inputTokens: 13,
|
||||||
|
outputTokens: 5,
|
||||||
|
totalTokens: 18,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// An OpenAI-compatible streams from a buggy server might never send usage, so response.completed should
|
||||||
|
// still wait for [DONE] but omit the usage object entirely.
|
||||||
|
name: "no usage chunk",
|
||||||
|
in: []string{
|
||||||
|
`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_no_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
},
|
||||||
|
doneInputIndex: 2,
|
||||||
|
hasUsage: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
completedCount := 0
|
||||||
|
completedInputIndex := -1
|
||||||
|
var completedData gjson.Result
|
||||||
|
|
||||||
|
// Reuse converter state across input lines to simulate one streaming response.
|
||||||
|
var param any
|
||||||
|
|
||||||
|
for i, line := range tt.in {
|
||||||
|
// One upstream chunk can emit multiple downstream SSE events.
|
||||||
|
for _, chunk := range ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m) {
|
||||||
|
event, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||||
|
if event != "response.completed" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
completedCount++
|
||||||
|
completedInputIndex = i
|
||||||
|
completedData = data
|
||||||
|
if i < tt.doneInputIndex {
|
||||||
|
t.Fatalf("unexpected early response.completed on input index %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if completedCount != 1 {
|
||||||
|
t.Fatalf("expected exactly 1 response.completed event, got %d", completedCount)
|
||||||
|
}
|
||||||
|
if completedInputIndex != tt.doneInputIndex {
|
||||||
|
t.Fatalf("expected response.completed on terminal [DONE] chunk at input index %d, got %d", tt.doneInputIndex, completedInputIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Missing upstream usage should stay omitted in the final completed event.
|
||||||
|
if !tt.hasUsage {
|
||||||
|
if completedData.Get("response.usage").Exists() {
|
||||||
|
t.Fatalf("expected response.completed to omit usage when none was provided, got %s", completedData.Get("response.usage").Raw)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// When usage is present, the final response.completed event must preserve the usage values.
|
||||||
|
if got := completedData.Get("response.usage.input_tokens").Int(); got != tt.inputTokens {
|
||||||
|
t.Fatalf("unexpected response.usage.input_tokens: got %d want %d", got, tt.inputTokens)
|
||||||
|
}
|
||||||
|
if got := completedData.Get("response.usage.output_tokens").Int(); got != tt.outputTokens {
|
||||||
|
t.Fatalf("unexpected response.usage.output_tokens: got %d want %d", got, tt.outputTokens)
|
||||||
|
}
|
||||||
|
if got := completedData.Get("response.usage.total_tokens").Int(); got != tt.totalTokens {
|
||||||
|
t.Fatalf("unexpected response.usage.total_tokens: got %d want %d", got, tt.totalTokens)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
|
||||||
|
in := []string{
|
||||||
|
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}
|
||||||
|
|
||||||
|
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out [][]byte
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
addedNames := map[string]string{}
|
||||||
|
doneArgs := map[string]string{}
|
||||||
|
doneNames := map[string]string{}
|
||||||
|
outputItems := map[string]gjson.Result{}
|
||||||
|
|
||||||
|
for _, chunk := range out {
|
||||||
|
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_item.added":
|
||||||
|
if data.Get("item.type").String() != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addedNames[data.Get("item.call_id").String()] = data.Get("item.name").String()
|
||||||
|
case "response.output_item.done":
|
||||||
|
if data.Get("item.type").String() != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := data.Get("item.call_id").String()
|
||||||
|
doneArgs[callID] = data.Get("item.arguments").String()
|
||||||
|
doneNames[callID] = data.Get("item.name").String()
|
||||||
|
case "response.completed":
|
||||||
|
output := data.Get("response.output")
|
||||||
|
for _, item := range output.Array() {
|
||||||
|
if item.Get("type").String() == "function_call" {
|
||||||
|
outputItems[item.Get("call_id").String()] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addedNames) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call added events, got %d", len(addedNames))
|
||||||
|
}
|
||||||
|
if len(doneArgs) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call done events, got %d", len(doneArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
if addedNames["call_read"] != "read" {
|
||||||
|
t.Fatalf("unexpected added name for call_read: %q", addedNames["call_read"])
|
||||||
|
}
|
||||||
|
if addedNames["call_glob"] != "glob" {
|
||||||
|
t.Fatalf("unexpected added name for call_glob: %q", addedNames["call_glob"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.Valid(doneArgs["call_read"]) {
|
||||||
|
t.Fatalf("invalid JSON args for call_read: %q", doneArgs["call_read"])
|
||||||
|
}
|
||||||
|
if !gjson.Valid(doneArgs["call_glob"]) {
|
||||||
|
t.Fatalf("invalid JSON args for call_glob: %q", doneArgs["call_glob"])
|
||||||
|
}
|
||||||
|
if strings.Contains(doneArgs["call_read"], "}{") {
|
||||||
|
t.Fatalf("call_read args were concatenated: %q", doneArgs["call_read"])
|
||||||
|
}
|
||||||
|
if strings.Contains(doneArgs["call_glob"], "}{") {
|
||||||
|
t.Fatalf("call_glob args were concatenated: %q", doneArgs["call_glob"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if doneNames["call_read"] != "read" {
|
||||||
|
t.Fatalf("unexpected done name for call_read: %q", doneNames["call_read"])
|
||||||
|
}
|
||||||
|
if doneNames["call_glob"] != "glob" {
|
||||||
|
t.Fatalf("unexpected done name for call_glob: %q", doneNames["call_glob"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := gjson.Get(doneArgs["call_read"], "filePath").String(); got != `C:\repo` {
|
||||||
|
t.Fatalf("unexpected filePath for call_read: %q", got)
|
||||||
|
}
|
||||||
|
if got := gjson.Get(doneArgs["call_glob"], "path").String(); got != `C:\repo` {
|
||||||
|
t.Fatalf("unexpected path for call_glob: %q", got)
|
||||||
|
}
|
||||||
|
if got := gjson.Get(doneArgs["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
|
||||||
|
t.Fatalf("unexpected pattern for call_glob: %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(outputItems) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call items in response.output, got %d", len(outputItems))
|
||||||
|
}
|
||||||
|
if outputItems["call_read"].Get("name").String() != "read" {
|
||||||
|
t.Fatalf("unexpected response.output name for call_read: %q", outputItems["call_read"].Get("name").String())
|
||||||
|
}
|
||||||
|
if outputItems["call_glob"].Get("name").String() != "glob" {
|
||||||
|
t.Fatalf("unexpected response.output name for call_glob: %q", outputItems["call_glob"].Get("name").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
|
||||||
|
in := []string{
|
||||||
|
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}
|
||||||
|
|
||||||
|
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out [][]byte
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fcEvent struct {
|
||||||
|
outputIndex int64
|
||||||
|
name string
|
||||||
|
arguments string
|
||||||
|
}
|
||||||
|
|
||||||
|
added := map[string]fcEvent{}
|
||||||
|
done := map[string]fcEvent{}
|
||||||
|
|
||||||
|
for _, chunk := range out {
|
||||||
|
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_item.added":
|
||||||
|
if data.Get("item.type").String() != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := data.Get("item.call_id").String()
|
||||||
|
added[callID] = fcEvent{
|
||||||
|
outputIndex: data.Get("output_index").Int(),
|
||||||
|
name: data.Get("item.name").String(),
|
||||||
|
}
|
||||||
|
case "response.output_item.done":
|
||||||
|
if data.Get("item.type").String() != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := data.Get("item.call_id").String()
|
||||||
|
done[callID] = fcEvent{
|
||||||
|
outputIndex: data.Get("output_index").Int(),
|
||||||
|
name: data.Get("item.name").String(),
|
||||||
|
arguments: data.Get("item.arguments").String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(added) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call added events, got %d", len(added))
|
||||||
|
}
|
||||||
|
if len(done) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call done events, got %d", len(done))
|
||||||
|
}
|
||||||
|
|
||||||
|
if added["call_choice0"].name != "glob" {
|
||||||
|
t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
|
||||||
|
}
|
||||||
|
if added["call_choice1"].name != "read" {
|
||||||
|
t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
|
||||||
|
}
|
||||||
|
if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
|
||||||
|
t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.Valid(done["call_choice0"].arguments) {
|
||||||
|
t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
|
||||||
|
}
|
||||||
|
if !gjson.Valid(done["call_choice1"].arguments) {
|
||||||
|
t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
|
||||||
|
}
|
||||||
|
if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
|
||||||
|
t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
|
||||||
|
}
|
||||||
|
if done["call_choice0"].name != "glob" {
|
||||||
|
t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
|
||||||
|
}
|
||||||
|
if done["call_choice1"].name != "read" {
|
||||||
|
t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
|
||||||
|
in := []string{
|
||||||
|
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}
|
||||||
|
|
||||||
|
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out [][]byte
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var messageOutputIndex int64 = -1
|
||||||
|
var toolOutputIndex int64 = -1
|
||||||
|
|
||||||
|
for _, chunk := range out {
|
||||||
|
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||||
|
if ev != "response.output_item.added" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch data.Get("item.type").String() {
|
||||||
|
case "message":
|
||||||
|
if data.Get("item.id").String() == "msg_resp_mixed_0" {
|
||||||
|
messageOutputIndex = data.Get("output_index").Int()
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
if data.Get("item.call_id").String() == "call_choice1" {
|
||||||
|
toolOutputIndex = data.Get("output_index").Int()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messageOutputIndex < 0 {
|
||||||
|
t.Fatal("did not find message output index")
|
||||||
|
}
|
||||||
|
if toolOutputIndex < 0 {
|
||||||
|
t.Fatal("did not find tool output index")
|
||||||
|
}
|
||||||
|
if messageOutputIndex == toolOutputIndex {
|
||||||
|
t.Fatalf("expected distinct output indexes for message and tool call, both got %d", messageOutputIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
|
||||||
|
in := []string{
|
||||||
|
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
|
||||||
|
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}
|
||||||
|
|
||||||
|
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
|
||||||
|
|
||||||
|
var param any
|
||||||
|
var out [][]byte
|
||||||
|
for _, line := range in {
|
||||||
|
out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var doneIndexes []int64
|
||||||
|
var completedOrder []string
|
||||||
|
|
||||||
|
for _, chunk := range out {
|
||||||
|
ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
|
||||||
|
switch ev {
|
||||||
|
case "response.output_item.done":
|
||||||
|
if data.Get("item.type").String() == "function_call" {
|
||||||
|
doneIndexes = append(doneIndexes, data.Get("output_index").Int())
|
||||||
|
}
|
||||||
|
case "response.completed":
|
||||||
|
for _, item := range data.Get("response.output").Array() {
|
||||||
|
if item.Get("type").String() == "function_call" {
|
||||||
|
completedOrder = append(completedOrder, item.Get("call_id").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(doneIndexes) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call done indexes, got %d", len(doneIndexes))
|
||||||
|
}
|
||||||
|
if doneIndexes[0] >= doneIndexes[1] {
|
||||||
|
t.Fatalf("expected ascending done output indexes, got %v", doneIndexes)
|
||||||
|
}
|
||||||
|
if len(completedOrder) != 2 {
|
||||||
|
t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedOrder))
|
||||||
|
}
|
||||||
|
if completedOrder[0] != "call_glob" || completedOrder[1] != "call_read" {
|
||||||
|
t.Fatalf("unexpected completed function_call order: %v", completedOrder)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -201,6 +201,7 @@ var zhStrings = map[string]string{
|
|||||||
"usage_output": "输出",
|
"usage_output": "输出",
|
||||||
"usage_cached": "缓存",
|
"usage_cached": "缓存",
|
||||||
"usage_reasoning": "思考",
|
"usage_reasoning": "思考",
|
||||||
|
"usage_time": "时间",
|
||||||
|
|
||||||
// ── Logs ──
|
// ── Logs ──
|
||||||
"logs_title": "📋 日志",
|
"logs_title": "📋 日志",
|
||||||
@@ -352,6 +353,7 @@ var enStrings = map[string]string{
|
|||||||
"usage_output": "Output",
|
"usage_output": "Output",
|
||||||
"usage_cached": "Cached",
|
"usage_cached": "Cached",
|
||||||
"usage_reasoning": "Reasoning",
|
"usage_reasoning": "Reasoning",
|
||||||
|
"usage_time": "Time",
|
||||||
|
|
||||||
// ── Logs ──
|
// ── Logs ──
|
||||||
"logs_title": "📋 Logs",
|
"logs_title": "📋 Logs",
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user